From 0e1cc5ffd6fe7ca26074ea8e9ac96aed38a23ba6 Mon Sep 17 00:00:00 2001 From: anakinxc <103552181+anakinxc@users.noreply.github.com> Date: Fri, 22 Sep 2023 15:20:48 +0800 Subject: [PATCH] Repo sync (#357) # Pull Request ## What problem does this PR solve? Issue Number: Fixed # ## Possible side effects? - Performance: - Backward compatibility: --- .github/ISSUE_TEMPLATE/bug_report.yml | 2 - libspu/compiler/core/core.cc | 1 + libspu/compiler/front_end/BUILD.bazel | 1 - libspu/compiler/front_end/hlo_importer.cc | 3 - libspu/compiler/passes/BUILD.bazel | 13 ++ libspu/compiler/passes/passes.h | 2 + libspu/compiler/passes/passes.td | 6 + libspu/compiler/passes/sort_lowering.cc | 97 +++++++++++ libspu/compiler/tests/sort_lowering.mlir | 11 ++ libspu/core/context.h | 4 +- libspu/core/ndarray_ref.h | 22 ++- libspu/core/trace.h | 2 +- libspu/device/pphlo/pphlo_executor.cc | 26 +++ libspu/device/pphlo/pphlo_verifier.h | 1 + libspu/dialect/pphlo_base_enums.td | 16 ++ libspu/dialect/pphlo_ops.td | 16 ++ libspu/kernel/hal/prot_wrapper.cc | 6 + libspu/kernel/hal/prot_wrapper.h | 2 + libspu/kernel/hal/sort.cc | 32 +++- libspu/kernel/hal/sort.h | 12 ++ libspu/kernel/hlo/sort.cc | 38 ++++- libspu/kernel/hlo/sort.h | 5 + sml/README.md | 88 ++++++++++ sml/decomposition/nmf.py | 12 +- sml/decomposition/tests/nmf_test.py | 199 ++++++++++------------ sml/development.md | 60 +++++++ sml/faq.md | 87 ++++++++++ sml/sml_develop.svg | 1 + sml/support_lists.md | 25 +++ 29 files changed, 659 insertions(+), 131 deletions(-) create mode 100644 libspu/compiler/passes/sort_lowering.cc create mode 100644 libspu/compiler/tests/sort_lowering.mlir create mode 100644 sml/development.md create mode 100644 sml/faq.md create mode 100644 sml/sml_develop.svg create mode 100644 sml/support_lists.md diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index 51d932b1..f2d36609 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -110,5 +110,3 @@ body: description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks. render: Shell - - diff --git a/libspu/compiler/core/core.cc b/libspu/compiler/core/core.cc index 4e531a08..8a81ae9a 100644 --- a/libspu/compiler/core/core.cc +++ b/libspu/compiler/core/core.cc @@ -53,6 +53,7 @@ void Core::buildPipeline(mlir::PassManager *pm) { } optPM.addPass(mlir::pphlo::createDecomposeComparisonPass()); optPM.addPass(mlir::pphlo::createDecomposeMinMaxPass()); + optPM.addPass(mlir::pphlo::createSortLowering()); if (!options.disable_sqrt_plus_epsilon_rewrite()) { optPM.addPass(mlir::pphlo::createOptimizeSqrtPlusEps()); diff --git a/libspu/compiler/front_end/BUILD.bazel b/libspu/compiler/front_end/BUILD.bazel index 06cdbb17..c7e8c6f4 100644 --- a/libspu/compiler/front_end/BUILD.bazel +++ b/libspu/compiler/front_end/BUILD.bazel @@ -53,7 +53,6 @@ spu_cc_library( "@xla//xla/service:scatter_expander", "@xla//xla/service:slice_sinker", "@xla//xla/service:sort_simplifier", - "@xla//xla/service:stable_sort_expander", "@xla//xla/service:triangular_solve_expander", "@xla//xla/service:while_loop_constant_sinking", "@xla//xla/service:while_loop_simplifier", diff --git a/libspu/compiler/front_end/hlo_importer.cc b/libspu/compiler/front_end/hlo_importer.cc index 21b18cff..977d9c98 100644 --- a/libspu/compiler/front_end/hlo_importer.cc +++ b/libspu/compiler/front_end/hlo_importer.cc @@ -47,7 +47,6 @@ #include "xla/service/scatter_expander.h" #include "xla/service/slice_sinker.h" #include "xla/service/sort_simplifier.h" -#include "xla/service/stable_sort_expander.h" #include "xla/service/triangular_solve_expander.h" #include "xla/service/tuple_simplifier.h" #include "xla/service/while_loop_constant_sinking.h" @@ -105,8 +104,6 @@ void runHloPasses(xla::HloModule *module) { pipeline.AddPass(); - pipeline.AddPass(); - // After canonicalization, there may be more batch dots that can be // simplified. pipeline.AddPass(); diff --git a/libspu/compiler/passes/BUILD.bazel b/libspu/compiler/passes/BUILD.bazel index 96088b94..913431bd 100644 --- a/libspu/compiler/passes/BUILD.bazel +++ b/libspu/compiler/passes/BUILD.bazel @@ -246,6 +246,18 @@ spu_cc_library( ], ) +spu_cc_library( + name = "sort_lowering", + srcs = ["sort_lowering.cc"], + hdrs = ["passes.h"], + deps = [ + ":pass_details", + "//libspu/dialect:pphlo_dialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:TransformUtils", + ], +) + spu_cc_library( name = "all_passes", hdrs = ["register_passes.h"], @@ -263,5 +275,6 @@ spu_cc_library( ":optimize_sqrt_plus_eps", ":reduce_truncation", ":rewrite_div_sqrt_patterns", + ":sort_lowering", ], ) diff --git a/libspu/compiler/passes/passes.h b/libspu/compiler/passes/passes.h index ddcc32ba..6fc54988 100644 --- a/libspu/compiler/passes/passes.h +++ b/libspu/compiler/passes/passes.h @@ -73,6 +73,8 @@ createOptimizeDenominatorWithBroadcast(); std::unique_ptr> createInsertDeallocationOp(); +std::unique_ptr> createSortLowering(); + } // namespace pphlo } // namespace mlir diff --git a/libspu/compiler/passes/passes.td b/libspu/compiler/passes/passes.td index 05055234..ae045690 100644 --- a/libspu/compiler/passes/passes.td +++ b/libspu/compiler/passes/passes.td @@ -98,3 +98,9 @@ def InsertDeallocation: Pass<"insert-deallocation", "func::FuncOp"> { let constructor = "createInsertDeallocationOp()"; let dependentDialects = ["pphlo::PPHloDialect"]; } + +def SortLowering: Pass<"sort-lowering", "func::FuncOp"> { + let summary = "Lower some simple sort to simple sort op"; + let constructor = "createSortLowering()"; + let dependentDialects = ["pphlo::PPHloDialect"]; +} diff --git a/libspu/compiler/passes/sort_lowering.cc b/libspu/compiler/passes/sort_lowering.cc new file mode 100644 index 00000000..c6a51828 --- /dev/null +++ b/libspu/compiler/passes/sort_lowering.cc @@ -0,0 +1,97 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "spdlog/spdlog.h" + +#include "libspu/compiler/passes/pass_details.h" +#include "libspu/compiler/passes/passes.h" +#include "libspu/dialect/pphlo_ops.h" +#include "libspu/dialect/pphlo_types.h" + +namespace mlir::pphlo { + +namespace { + +struct SortConversion : public OpRewritePattern { +public: + explicit SortConversion(MLIRContext *context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(SortOp op, + PatternRewriter &rewriter) const override { + auto &comp = op.getComparator(); + if (op->getNumOperands() == 1) { + // When there is only one operand, stable or not seems irrelevant + op.setIsStable(false); + } + + // If has a single instruction comparator, check if it's a simple sort. + if (comp.hasOneBlock() && + llvm::hasSingleElement(comp.front().without_terminator())) { + auto &inst = comp.front().front(); + // Single instruction comparator. + if (mlir::isa(inst) || mlir::isa(inst)) { + mlir::IntegerAttr direction; + if (mlir::isa(inst)) { + // descent + direction = rewriter.getI32IntegerAttr( + static_cast(SortDirection::DES)); + } else { + // ascent + direction = rewriter.getI32IntegerAttr( + static_cast(SortDirection::ASC)); + } + auto lhs_idx = + inst.getOperand(0).dyn_cast().getArgNumber(); + auto rhs_idx = + inst.getOperand(1).dyn_cast().getArgNumber(); + // FIXME: If the comparator is using operands other than the first one, + // we should just reorder operands instead of bailout + if (lhs_idx != 0 || rhs_idx != 1) { + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, op.getResultTypes(), op.getOperands(), op.getDimensionAttr(), + direction); + return success(); + } + } + return failure(); + } +}; + +struct SortLowering : public SortLoweringBase { + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateOwningPatterns(&patterns, &getContext()); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } + +private: + static void populateOwningPatterns(RewritePatternSet *patterns, + MLIRContext *ctx) { + patterns->insert(ctx); + } +}; +} // namespace + +std::unique_ptr> createSortLowering() { + return std::make_unique(); +} + +} // namespace mlir::pphlo diff --git a/libspu/compiler/tests/sort_lowering.mlir b/libspu/compiler/tests/sort_lowering.mlir new file mode 100644 index 00000000..8b751a09 --- /dev/null +++ b/libspu/compiler/tests/sort_lowering.mlir @@ -0,0 +1,11 @@ +// RUN: mlir-pphlo-opt --sort-lowering --split-input-file %s | FileCheck %s + +func.func @main(%arg0: tensor<10x!pphlo.sec>) -> tensor<10x!pphlo.sec> { + // CHECK: simple_sort + %0 = "pphlo.sort"(%arg0) ({ + ^bb0(%arg1: tensor>, %arg2: tensor>): + %1 = "pphlo.less"(%arg1, %arg2) : (tensor>, tensor>) -> tensor> + "pphlo.return"(%1) : (tensor>) -> () + }) {dimension = 0 : i64, is_stable = false} : (tensor<10x!pphlo.sec>) -> tensor<10x!pphlo.sec> + return %0 : tensor<10x!pphlo.sec> + } diff --git a/libspu/core/context.h b/libspu/core/context.h index 3b8b920d..7aacac0f 100644 --- a/libspu/core/context.h +++ b/libspu/core/context.h @@ -94,7 +94,9 @@ class KernelEvalContext final { Type, // type of type uint128_t, // ring constant int64_t, // - SignType // + SignType, // + std::vector, // for sort + absl::Span // for sort >; SPUContext* sctx_; diff --git a/libspu/core/ndarray_ref.h b/libspu/core/ndarray_ref.h index 7d3fb28f..126ad1ea 100644 --- a/libspu/core/ndarray_ref.h +++ b/libspu/core/ndarray_ref.h @@ -71,11 +71,29 @@ class NdArrayRef { NdArrayRef(const NdArrayRef& other) = default; NdArrayRef(NdArrayRef&& other) = default; NdArrayRef& operator=(const NdArrayRef& other) = default; + +#ifndef NDEBUG + // GCC 11.4 with -O1 is not happy with default assign operator when using + // std::reverse... + NdArrayRef& operator=(NdArrayRef&& other) noexcept { + if (this != &other) { + std::swap(this->buf_, other.buf_); + std::swap(this->eltype_, other.eltype_); + std::swap(this->shape_, other.shape_); + std::swap(this->strides_, other.strides_); + std::swap(this->offset_, other.offset_); + std::swap(this->use_fast_indexing_, other.use_fast_indexing_); + std::swap(this->fast_indexing_stride_, other.fast_indexing_stride_); + } + return *this; + } +#else NdArrayRef& operator=(NdArrayRef&& other) = default; +#endif bool operator==(const NdArrayRef& other) const { - return buf_ == other.buf_ && shape_ == other.shape_ && - strides_ == other.strides_ && offset_ == other.offset_; + return shape_ == other.shape_ && strides_ == other.strides_ && + offset_ == other.offset_ && buf_ == other.buf_; } bool operator!=(const NdArrayRef& other) const { return !(*this == other); } diff --git a/libspu/core/trace.h b/libspu/core/trace.h index cf2c9fe7..47bec4ea 100644 --- a/libspu/core/trace.h +++ b/libspu/core/trace.h @@ -37,7 +37,7 @@ inline std::ostream& operator<<(std::ostream& os, } inline std::ostream& operator<<(std::ostream& os, - const absl::Span& indices) { + absl::Span indices) { os << fmt::format("{{{}}}", fmt::join(indices, ",")); return os; } diff --git a/libspu/device/pphlo/pphlo_executor.cc b/libspu/device/pphlo/pphlo_executor.cc index 3c580dbf..88e73d39 100644 --- a/libspu/device/pphlo/pphlo_executor.cc +++ b/libspu/device/pphlo/pphlo_executor.cc @@ -572,6 +572,32 @@ void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope, } } +void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope, + mlir::pphlo::SimpleSortOp &op, const ExecutionOptions &opts) { + auto sort_dim = op.getDimension(); + std::vector inputs(op->getNumOperands()); + for (size_t idx = 0; idx < inputs.size(); ++idx) { + inputs[idx] = lookupValue(sscope, op->getOperand(idx), opts); + } + + kernel::hal::SortDirection direction; + if (op.getSortDirectionAttr().getInt() == + static_cast(mlir::pphlo::SortDirection::ASC)) { + direction = kernel::hal::SortDirection::Ascending; + } else if (op.getSortDirectionAttr().getInt() == + static_cast(mlir::pphlo::SortDirection::DES)) { + direction = kernel::hal::SortDirection::Descending; + } else { + SPU_THROW("Should not reach here"); + } + + auto ret = kernel::hlo::SimpleSort(sctx, inputs, sort_dim, direction); + + for (int64_t idx = 0; idx < op->getNumResults(); ++idx) { + addValue(sscope, op->getResult(idx), ret[idx], opts); + } +} + void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope, mlir::pphlo::SelectAndScatterOp &op, const ExecutionOptions &opts) { diff --git a/libspu/device/pphlo/pphlo_verifier.h b/libspu/device/pphlo/pphlo_verifier.h index 9b007956..d988ef78 100644 --- a/libspu/device/pphlo/pphlo_verifier.h +++ b/libspu/device/pphlo/pphlo_verifier.h @@ -149,6 +149,7 @@ class PPHloVerifier { NO_VERIFY_DEFN(RealOp) NO_VERIFY_DEFN(ImagOp) NO_VERIFY_DEFN(ComplexOp) + NO_VERIFY_DEFN(SimpleSortOp) #undef NO_VERIFY_DEFN }; diff --git a/libspu/dialect/pphlo_base_enums.td b/libspu/dialect/pphlo_base_enums.td index 269486af..b609f326 100644 --- a/libspu/dialect/pphlo_base_enums.td +++ b/libspu/dialect/pphlo_base_enums.td @@ -31,4 +31,20 @@ def PPHLO_VisibilityAttr let genSpecializedAttr = 0; } +//===----------------------------------------------------------------------===// +// Sort direction enum definitions. +//===----------------------------------------------------------------------===// +def PPHLO_SORT_DIRECTION_ASC : I32EnumAttrCase<"ASC", 0>; +def PPHLO_SORT_DIRECTION_DES : I32EnumAttrCase<"DES", 1>; + +def PPHLO_SortDirection : I32EnumAttr<"SortDirection", + "Which mode to sort.", + [ + PPHLO_SORT_DIRECTION_ASC, + PPHLO_SORT_DIRECTION_DES + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::pphlo"; +} + #endif // SPU_DIALECT_PPHLO_BASE_ENUMS diff --git a/libspu/dialect/pphlo_ops.td b/libspu/dialect/pphlo_ops.td index c7618518..eb52637a 100644 --- a/libspu/dialect/pphlo_ops.td +++ b/libspu/dialect/pphlo_ops.td @@ -938,6 +938,22 @@ def PPHLO_SortOp CArg<"bool", "false">:$is_stable)>]; } +def PPHLO_SimpleSortOp + : PPHLO_Op<"simple_sort", [RecursiveMemoryEffects, SameOperandsAndResultShape]> { + let summary = "Sort operator"; + let description = [{ + Sorts the given `operands` at the given `dimension` with the given + `mode`. + }]; + let arguments = (ins + Variadic:$operands, + DefaultValuedAttr:$dimension, + PPHLO_SortDirection: $sort_direction + ); + + let results = (outs Variadic); +} + def PPHLO_ReverseOp : PPHLO_Op<"reverse", [Pure, SameOperandsAndResultType]> { let summary = "Reverse operator"; diff --git a/libspu/kernel/hal/prot_wrapper.cc b/libspu/kernel/hal/prot_wrapper.cc index 9cf0f51d..d850703d 100644 --- a/libspu/kernel/hal/prot_wrapper.cc +++ b/libspu/kernel/hal/prot_wrapper.cc @@ -108,6 +108,12 @@ Value _trunc_s(SPUContext* ctx, const Value& in, size_t bits, SignType sign) { return mpc::trunc_s(ctx, in, bits, sign); } +std::vector _sort_s(SPUContext* ctx, absl::Span x) { + SPU_TRACE_HAL_DISP(ctx, x.size()); + // FIXME(jimi): formalize mpc sort api + return dynDispatch>(ctx, "sort_a", x); +} + MAP_UNARY_OP(p2s) MAP_UNARY_OP(s2p) MAP_UNARY_OP(not_p) diff --git a/libspu/kernel/hal/prot_wrapper.h b/libspu/kernel/hal/prot_wrapper.h index ac5126cf..131f533e 100644 --- a/libspu/kernel/hal/prot_wrapper.h +++ b/libspu/kernel/hal/prot_wrapper.h @@ -82,6 +82,8 @@ Value _make_p(SPUContext* ctx, uint128_t init, const Shape& shape); Value _rand_p(SPUContext* ctx, const Shape& shape); Value _rand_s(SPUContext* ctx, const Shape& shape); +std::vector _sort_s(SPUContext* ctx, absl::Span x); + // NOLINTEND(readability-identifier-naming) } // namespace spu::kernel::hal diff --git a/libspu/kernel/hal/sort.cc b/libspu/kernel/hal/sort.cc index 5f924ada..de38aa1c 100644 --- a/libspu/kernel/hal/sort.cc +++ b/libspu/kernel/hal/sort.cc @@ -1,6 +1,9 @@ #include "libspu/kernel/hal/sort.h" +#include + #include "libspu/kernel/hal/polymorphic.h" +#include "libspu/kernel/hal/prot_wrapper.h" #include "libspu/kernel/hal/public_helper.h" #include "libspu/kernel/hal/ring.h" #include "libspu/kernel/hal/shape_ops.h" @@ -172,11 +175,38 @@ std::vector sort1d(SPUContext *ctx, for (auto const &input : inputs) { ret.push_back(input.clone()); } - // TODO(jimi): leave a special dispatch path for radix sort BitonicSort(ctx, cmp, ret); } return ret; } +std::vector simple_sort1d(SPUContext *ctx, + absl::Span inputs, + SortDirection direction) { + // Fall back to generic sort + SPU_ENFORCE(!inputs.empty(), "Inputs should not be empty"); + if (inputs[0].isPublic() || !ctx->hasKernel("sort_a")) { + auto ret = sort1d( + ctx, inputs, + [&](absl::Span cmp_inputs) { + if (direction == SortDirection::Ascending) { + return hal::less(ctx, cmp_inputs[0], cmp_inputs[1]); + } + if (direction == SortDirection::Descending) { + return hal::greater(ctx, cmp_inputs[0], cmp_inputs[1]); + } + SPU_THROW("Should not reach here"); + }, + inputs[0].vtype(), false); + return ret; + } else { + auto ret = _sort_s(ctx, inputs); + if (direction == SortDirection::Descending) { + std::reverse(ret.begin(), ret.end()); + } + return ret; + } +} + } // namespace spu::kernel::hal \ No newline at end of file diff --git a/libspu/kernel/hal/sort.h b/libspu/kernel/hal/sort.h index bc0788a9..b0000905 100644 --- a/libspu/kernel/hal/sort.h +++ b/libspu/kernel/hal/sort.h @@ -9,9 +9,21 @@ namespace spu::kernel::hal { using CompFn = std::function)>; +// simple sort direction +enum class SortDirection { + Ascending, + Descending, +}; + +// general sort1d with comparator std::vector sort1d(SPUContext *ctx, absl::Span inputs, const CompFn &cmp, Visibility comparator_ret_vis, bool is_stable); +// simple sort1d without comparator +std::vector simple_sort1d(SPUContext *ctx, + absl::Span inputs, + SortDirection direction); + } // namespace spu::kernel::hal \ No newline at end of file diff --git a/libspu/kernel/hlo/sort.cc b/libspu/kernel/hlo/sort.cc index c64ef25e..ea448482 100644 --- a/libspu/kernel/hlo/sort.cc +++ b/libspu/kernel/hlo/sort.cc @@ -20,7 +20,10 @@ namespace spu::kernel::hlo { -namespace { +namespace internal { + +using Sort1dFn = + std::function(absl::Span)>; // Given a & p are vectors, and p is a permutation. // let b = permute(a, p) where b[i] = a[p[i]] @@ -44,13 +47,9 @@ Index InversePermute(const Index &p) { return q; } -} // namespace - std::vector Sort(SPUContext *ctx, absl::Span inputs, - int64_t sort_dim, bool is_stable, - const hal::CompFn &comparator_body, - Visibility comparator_ret_vis) { + int64_t sort_dim, const Sort1dFn &sort_fn) { // sanity check. SPU_ENFORCE(!inputs.empty(), "Inputs should not be empty"); // put the to_sort dimension to last dimension. @@ -99,8 +98,7 @@ std::vector Sort(SPUContext *ctx, hal::reshape(ctx, hal::slice(ctx, input, {ni, 0}, {ni + 1, W}), {W})); } - sorted1d.push_back(hal::sort1d(ctx, input_i, comparator_body, - comparator_ret_vis, is_stable)); + sorted1d.push_back(sort_fn(input_i)); } // result is (M,shape) @@ -119,4 +117,28 @@ std::vector Sort(SPUContext *ctx, return results; } +} // namespace internal + +std::vector Sort(SPUContext *ctx, + absl::Span inputs, + int64_t sort_dim, bool is_stable, + const hal::CompFn &comparator_body, + Visibility comparator_ret_vis) { + auto sort_fn = [&](absl::Span input) { + return hal::sort1d(ctx, input, comparator_body, comparator_ret_vis, + is_stable); + }; + return internal::Sort(ctx, inputs, sort_dim, sort_fn); +} + +std::vector SimpleSort(SPUContext *ctx, + absl::Span inputs, + int64_t sort_dim, + hal::SortDirection direction) { + auto sort_fn = [&](absl::Span input) { + return hal::simple_sort1d(ctx, input, direction); + }; + return internal::Sort(ctx, inputs, sort_dim, sort_fn); +} + } // namespace spu::kernel::hlo diff --git a/libspu/kernel/hlo/sort.h b/libspu/kernel/hlo/sort.h index 26be69c7..e20d0d13 100644 --- a/libspu/kernel/hlo/sort.h +++ b/libspu/kernel/hlo/sort.h @@ -24,4 +24,9 @@ std::vector Sort(SPUContext* ctx, const hal::CompFn& comparator_body, Visibility comparator_ret_vis); +std::vector SimpleSort(SPUContext* ctx, + absl::Span inputs, + int64_t sort_dim, + hal::SortDirection direction); + } // namespace spu::kernel::hlo diff --git a/sml/README.md b/sml/README.md index e69de29b..4ae790d6 100644 --- a/sml/README.md +++ b/sml/README.md @@ -0,0 +1,88 @@ +# SML: Secure Machine Learning + +**SML** is a python module implementing machine learning algorithm with [JAX](https://github.com/google/jax), +which can do **secure** training and inferring under the magic of [SPU](https://github.com/secretflow/spu). + +Our vision is to establish a general-purpose privacy-preserving machine learning(PPML) library, +being a secure version of [scikit-learn](https://github.com/scikit-learn/scikit-learn). + +Normally, the APIs of our algorithms are designed to be as consistent as possible with scikit-learn. +However, due to safety considerations and certain limitations of the SPU, some APIs will undergo changes. +Detailed explanations will be provided for any differences in the doc. + +## Why not scikit-learn + +First, scikit-learn is built top on Numpy and SciPy, running on centralized mode. +So you must collect all data into one node, which can't protect the privacy of data. + +The implementations in scikit-learn are usually very efficient and valid, then why not we just "translate" it to MPC? + +The quick answer for this question is **accuracy** and **efficiency**. + +In PPML, we observe that most framework encodes floating-point to fixed-point number, +which parameterized by `field`(bitwidth of underlying integer) and `fxp_fraction_bits`(fractional part bitwidth), +greatly restricting the effective range and precision of floating-point numbers. +on other hand, The major determinant of computational overhead is determined by the MPC protocol, +so the origin cpu-friendly ops may have pool performance. + +### Our Solution + +So we establish a new library SML trying to bridge these gaps: + +1. accuracy: optimize and test the algorithm based on fixed-point number, +e.g. prefer high-precision ops(`rsqrt` rather than `1/sqrt`), +essential re-transform to accommodate the valid range of non-linear ops +(see [fxp pitfalls](../docs/development/fxp.ipynb)). +2. efficiency: use MPC-friendly op to replace CPU-friendly op, +e.g. use numeric approximation trick to avoid sophistic computation, prefer arithmetic ops to comparison ops. + +Of course, we also supply an easy-to-test toolbox for advanced developer +who wants to develop their own MPC program: + +1. `Simulator`: provide a fixed-point computation environment and run at high speed. +But it's unable to provide a real SPU performance environment, +the test results cannot reflect the actual performance of the algorithm. +2. `Emulator`: emulate on the real MPC protocol using multiple processes/Docker(coming soon), +and can provide effective performance results. + +So the **accuracy** can be proved if the algorithm pass the test of `simulator`, +and you should test the **efficiency** using `emulator`. + +> WARNING: currently, SML is undergoing rapid developments, +> so it is not recommended for direct use in production environments. + +## Installation + +First, you should clone the spu repo to your local disk: + +```bash +git clone https://github.com/secretflow/spu.git +``` + +Some [Prerequisites](../CONTRIBUTING.md#build) are required according to your system. +After all these installed, you can run any test like: + +```bash +# run kmeans simulation +# simulation: run program in single process +# used for correctness test +bazel run -c opt //sml/cluster/tests:kmeans_test + +# run kmeans emulation +# emulation: run program with multiple processes(LAN setting) +# or multiple dockers(WAN setting, will come soon) +# used for efficiency test. +bazel run -c opt //sml/cluster/emulations:kmeans_emul +``` + +## Algorithm Support lists + +See [support lists](./support_lists.md) for all our algorithms and features we support. + +## Development + +See [development](./development.md) if you would like to contribute to SML. + +## FAQ + +We collect some [FAQ](./faq.md), you can check it first before submitting an issue. diff --git a/sml/decomposition/nmf.py b/sml/decomposition/nmf.py index 680c3483..6776fe78 100644 --- a/sml/decomposition/nmf.py +++ b/sml/decomposition/nmf.py @@ -134,7 +134,7 @@ def __init__( self._random_matrixA = random_matrixA self._random_matrixB = random_matrixB self._update_H = True - self._components = None + self.components_ = None def fit(self, X): """Learn a NMF model for the data X. @@ -168,7 +168,7 @@ def fit(self, X): return self def transform(self, X, transform_iter=40): - assert self._components is not None, f"should fit before transform" + assert self.components_ is not None, f"should fit before transform" self._update_H = False self._max_iter = transform_iter W = self.fit_transform(X) @@ -193,7 +193,7 @@ def fit_transform(self, X): else: avg = jnp.sqrt(X.mean() / self._n_components) W = jnp.full((X.shape[0], self._n_components), avg, dtype=X.dtype) - H = self._components + H = self.components_ # compute the regularization parameters l1_reg_W, l1_reg_H, l2_reg_W, l2_reg_H = compute_regularization( @@ -210,7 +210,7 @@ def fit_transform(self, X): # compute the reconstruction error if self._update_H: - self._components = H + self.components_ = H self.reconstruction_err_ = _beta_divergence( X, W, H, self._beta_loss, square_root=True ) @@ -218,5 +218,5 @@ def fit_transform(self, X): return W def inverse_transform(self, X): - assert self._components is not None, f"should fit before inverse_transform" - return jnp.dot(X, self._components) + assert self.components_ is not None, f"should fit before inverse_transform" + return jnp.dot(X, self.components_) diff --git a/sml/decomposition/tests/nmf_test.py b/sml/decomposition/tests/nmf_test.py index 44ec7b8e..114bd443 100644 --- a/sml/decomposition/tests/nmf_test.py +++ b/sml/decomposition/tests/nmf_test.py @@ -18,6 +18,7 @@ import numpy as np from sklearn.decomposition import NMF as SklearnNMF + import spu.spu_pb2 as spu_pb2 import spu.utils.simulation as spsim @@ -28,130 +29,116 @@ class UnitTests(unittest.TestCase): - def test_nmf(self): + @classmethod + def setUpClass(cls): + print(" ========= start test of NMF package ========= \n") + cls.random_seed = 0 + np.random.seed(cls.random_seed) + # NMF must use FM128 now, for heavy use of non-linear & matrix operations config = spu_pb2.RuntimeConfig( protocol=spu_pb2.ProtocolKind.ABY3, field=spu_pb2.FieldType.FM128, fxp_fraction_bits=30, ) - sim = spsim.Simulator(3, config) - - # Test fit_transform - def proc1(X, random_matrixA, random_matrixB): - model = NMF( - n_components=n_components, - l1_ratio=l1_ratio, - alpha_W=alpha_W, - random_matrixA=random_matrixA, - random_matrixB=random_matrixB, + cls.sim = spsim.Simulator(3, config) + + # generate some dummy test datas + cls.test_data = np.random.randint(1, 100, (100, 10)) * 1.0 + n_samples, n_features = cls.test_data.shape + + # random matrix should be generated in plaintext. + cls.n_components = 5 + random_state = np.random.RandomState(cls.random_seed) + cls.random_A = random_state.standard_normal(size=(cls.n_components, n_features)) + cls.random_B = random_state.standard_normal(size=(n_samples, cls.n_components)) + + # test hyper-parameters settings + cls.l1_ratio = 0.1 + cls.alpha_W = 0.01 + + @classmethod + def tearDownClass(cls): + print(" ========= test of NMF package end ========= \n") + + def _nmf_test_main(self, plaintext=True, mode="uniform"): + # uniform means model is fitted by fit_transform method + # seperate means model is fitted by first fit then transform + assert mode in ["uniform", "seperate"] + + # must define here, because test may run simultaneously + model = ( + SklearnNMF( + n_components=self.n_components, + init='random', + random_state=self.random_seed, + l1_ratio=self.l1_ratio, + solver="mu", # sml only implement this solver now. + alpha_W=self.alpha_W, + ) + if plaintext + else NMF( + n_components=self.n_components, + l1_ratio=self.l1_ratio, + alpha_W=self.alpha_W, + random_matrixA=self.random_A, + random_matrixB=self.random_B, ) + ) - W = model.fit_transform(X) - H = model._components + def proc(x): + if mode == "uniform": + W = model.fit_transform(x) + else: + model.fit(x) + W = model.transform(x) + + H = model.components_ X_reconstructed = model.inverse_transform(W) err = model.reconstruction_err_ + return W, H, X_reconstructed, err - # Create a simple dataset and random_matrix - X = np.random.randint(1, 100, (1000, 10)) - X = np.array(X, dtype=float) - n_samples, n_features = X.shape - n_components = 5 - random_seed = 0 - random_state = np.random.RandomState(random_seed) - A = random_state.standard_normal(size=(n_components, n_features)) - B = random_state.standard_normal(size=(n_samples, n_components)) - l1_ratio = 0.1 - alpha_W = 0.01 - - # Run the simulation - W, H, X_reconstructed, err = spsim.sim_jax(sim, proc1)(X, A, B) - print("reconstruction_error: ", err) - - # sklearn - model = SklearnNMF( - n_components=n_components, - init='random', - random_state=random_seed, - l1_ratio=l1_ratio, - solver="mu", - alpha_W=alpha_W, + run_func = ( + proc + if plaintext + else spsim.sim_jax( + self.sim, + proc, + ) ) - W_Sklearn = model.fit_transform(X) - H_Sklearn = model.components_ - X_reconstructed_Sklearn = model.inverse_transform(W_Sklearn) - err = model.reconstruction_err_ - print("reconstruction_error_sklearn: ", err) - self.assertTrue(np.allclose(W_Sklearn, W, atol=5e-1)) - self.assertTrue(np.allclose(H_Sklearn, H, atol=5e-1)) - self.assertTrue( - np.allclose(X_reconstructed_Sklearn, X_reconstructed, atol=5e-1) + + return run_func(self.test_data) + + def test_nmf_uniform(self): + print("============== start test of nmf uniform ==============\n") + + W, H, X_reconstructed, err = self._nmf_test_main(False, "uniform") + W_sk, H_sk, X_reconstructed_sk, err_sk = self._nmf_test_main(True, "uniform") + + np.testing.assert_allclose(err, err_sk, rtol=1, atol=1e-1) + np.testing.assert_allclose(W, W_sk, rtol=1, atol=1e-1) + np.testing.assert_allclose(H, H_sk, rtol=1, atol=1e-1) + np.testing.assert_allclose( + X_reconstructed, X_reconstructed_sk, rtol=1, atol=1e-1 ) + print("============== nmf uniform test pass ==============\n") + def test_nmf_seperate(self): - config = spu_pb2.RuntimeConfig( - protocol=spu_pb2.ProtocolKind.ABY3, - field=spu_pb2.FieldType.FM128, - fxp_fraction_bits=30, - ) - sim = spsim.Simulator(3, config) - - # Test fit and transform - def proc2(X, random_matrixA, random_matrixB): - model = NMF( - n_components=n_components, - l1_ratio=l1_ratio, - alpha_W=alpha_W, - random_matrixA=random_matrixA, - random_matrixB=random_matrixB, - ) + print("============== start test of nmf seperate ==============\n") - model.fit(X) - W = model.transform(X, transform_iter=40) - H = model._components - X_reconstructed = model.inverse_transform(W) - return W, H, X_reconstructed - - # Create a simple dataset and random_matrix - X = np.random.randint(1, 100, (1000, 10)) - X = np.array(X, dtype=float) - n_samples, n_features = X.shape - n_components = 5 - random_seed = 0 - random_state = np.random.RandomState(random_seed) - A = random_state.standard_normal(size=(n_components, n_features)) - B = random_state.standard_normal(size=(n_samples, n_components)) - l1_ratio = 0.1 - alpha_W = 0.01 - - # Run the simulation_seperate - W_seperate, H_seperate, X_reconstructed_seperate = spsim.sim_jax(sim, proc2)( - X, A, B - ) - print("W_matrix_spu_seperate: ", W_seperate[:5, :5]) - print("H_matrix_spu_seperate: ", H_seperate[:5, :5]) - print("X_reconstructed_spu_seperate: ", X_reconstructed_seperate[:5, :5]) - - # sklearn_seperate - model = SklearnNMF( - n_components=n_components, - init='random', - random_state=random_seed, - l1_ratio=l1_ratio, - solver="mu", - alpha_W=alpha_W, - ) - model.fit(X) - W_Sklearn_seperate = model.transform(X) - H_Sklearn_seperate = model.components_ - X_reconstructed_Sklearn_seperate = model.inverse_transform(W_Sklearn_seperate) - print("W_matrix_sklearn_seperate: ", W_Sklearn_seperate[:5, :5]) - print("H_matrix_sklearn_seperate: ", H_Sklearn_seperate[:5, :5]) - print( - "X_reconstructed_sklearn_seperate: ", - X_reconstructed_Sklearn_seperate[:5, :5], + W, H, X_reconstructed, err = self._nmf_test_main(False, "seperate") + W_sk, H_sk, X_reconstructed_sk, err_sk = self._nmf_test_main(True, "seperate") + + np.testing.assert_allclose(err, err_sk, rtol=1, atol=1e-1) + np.testing.assert_allclose(W, W_sk, rtol=1, atol=1e-1) + np.testing.assert_allclose(H, H_sk, rtol=1, atol=1e-1) + np.testing.assert_allclose( + X_reconstructed, X_reconstructed_sk, rtol=1, atol=1e-1 ) + print("============== nmf seperate test pass ==============\n") + if __name__ == "__main__": unittest.main() diff --git a/sml/development.md b/sml/development.md new file mode 100644 index 00000000..d4994668 --- /dev/null +++ b/sml/development.md @@ -0,0 +1,60 @@ +# Development + +We welcome developers of all skill levels to contribute their expertise. +There are many ways to contribute to SML including reporting a bug, improving the documentation and contributing new algorithm. +Of course, if you have any suggestion or feature request, feel free to open an [issue](https://github.com/secretflow/spu/issues). + +## Submitting a bug report + +If you want to submit an issue, please do your best to follow these guidelines which will make it easier and quicker to provide you with good feedback: + +- Contains a **short reproducible** code snippet, so anyone can reproduce the bug easily +- If an exception is raised, please provide **the full traceback**. +- including your operating system type, version of JAX, SPU(or commit id) + +## Contributing code + +![sml develop paradiam](./sml_develop.svg) + +> 1. To avoid duplicating work, it is highly advised that you search through the issue tracker and the PR list. +> If in doubt about duplicated work, or if you want to work on a non-trivial feature, +> it's **recommended** to first open an issue in the issue tracker to get some feedbacks from core developers. +> 2. Some essential [documents](https://www.secretflow.org.cn/docs/spu/latest/en-US) about SPU are highly recommended. +[This](../docs/tutorials/develop_your_first_mpc_application.ipynb) is a good first tutorial for new developers, +[pitfall](../docs/development/fxp.ipynb) will be a cheatsheet when you come across numerical problem. + +The preferred way to contribute to SML is to fork the main repository, then submit a “pull request” (PR). + +1. Create a GitHub account if you do not have one. +2. Fork the [project repository](https://github.com/secretflow/spu), +your can refer to [this](https://docs.github.com/en/get-started/quickstart/fork-a-repo) for more details. +3. Following the instructions on [CONTRIBUTING](../CONTRIBUTING.md), installing the prerequisites and running tests successfully. +4. Develop the feature on **your feature branch** on your computer, +using [Git](https://docs.github.com/en/get-started/quickstart/set-up-git) to do the version control. +5. Following [Before Pull Request](<./development.md#Before Pull Request>) to place or test your codes, +[these](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork) + to create a pull request from your fork. +6. Committers do code review and then merge. + +## Before Pull Request + +When finishing your coding work, you are supposed to do some extra work before pulling request. + +1. **Make sure your code is up-to-date**: It is often helpful to keep your local feature branch **synchronized** with +the latest changes of the main SPU repository. +2. **Place your codes properly**: Generally speaking, for every algorithm, at least 3 files are needed +(e.g. for kmeans, check [PR](https://github.com/secretflow/spu/pull/277/files) as an example). + - `kmeans.py`: implementation of algorithm or new features, it should be a **"jit-able"** program which run correctly in plaintext + (same or near to output from scikit-learn). + - `kmeans_test.py`: a unittest python file, in which you test your program with **simulator**, then you should report the behavior + (like correctness or error rate) under MPC setting. + - `kmeans_emul.py`: similar to the above file, except you will test program with **emulator**, + then you can get sense of efficiency under different MPC protocols. +3. **Other things**: there are still some small fixes to do. + - **Add copyright**: see [this](<../CONTRIBUTING.md#Contributor License Agreement>) for details. + - **Add necessary doc**: your implementation may only have part features, or some changes have been made for limitation of both JAX and SPU, + you **MUST** describe these things explicitly! + - **Add/change bazel file**: currently, we adopt [bazel](https://github.com/bazelbuild/bazel) to manage our module, + so you might need add some [python rules](https://bazel.build/reference/be/python) in `BUILD.bazel`. + - **Format all your files**: using [buildifier](https://github.com/bazelbuild/buildtools/tree/master/buildifier) to format bazel file, + [black](https://github.com/psf/black) to format python file, [isort](https://github.com/PyCQA/isort) to sort the python imports. diff --git a/sml/faq.md b/sml/faq.md new file mode 100644 index 00000000..8b05d3b9 --- /dev/null +++ b/sml/faq.md @@ -0,0 +1,87 @@ +# FAQ + +1. How can I know the supported operations in SPU? + + **Ans**: check [jax numpy status](https://www.secretflow.org.cn/docs/spu/latest/en-US/reference/np_op_status) or [xla status](https://www.secretflow.org.cn/docs/spu/latest/en-US/reference/xla_status); + or you can direct test it using simulator or emulator. + +2. How to adjust **field** and **fxp_fraction_bits** to improve precision? + + **Ans**: For simulator, you can create a `RuntimeConfig` Object and then pass to `Simulator`. + + ```python + # for simulator + config = spu_pb2.RuntimeConfig( + protocol=spu_pb2.ProtocolKind.ABY3, + field=spu_pb2.FieldType.FM128, # change filed size here + fxp_fraction_bits=30, # change fxp here + ) + sim = spsim.Simulator(3, config) + ``` + + For emulator, you can define a config, e.g. named `spu_128.json`, + + ```json + { + "id": "outsourcing.3pc", + "nodes": { + "node:0": "127.0.0.1:9920", + "node:1": "127.0.0.1:9921", + "node:2": "127.0.0.1:9922", + "node:3": "127.0.0.1:9923", + "node:4": "127.0.0.1:9924" + }, + "devices": { + "SPU": { + "kind": "SPU", + "config": { + "node_ids": [ + "node:0", + "node:1", + "node:2" + ], + "spu_internal_addrs": [ + "127.0.0.1:9930", + "127.0.0.1:9931", + "127.0.0.1:9932" + ], + "runtime_config": { + "protocol": "ABY3", + "field": "FM128", + "fxp_fraction_bits": 30, + "enable_pphlo_profile": true, + "enable_hal_profile": true + } + } + }, + "P1": { + "kind": "PYU", + "config": { + "node_id": "node:3" + } + }, + "P2": { + "kind": "PYU", + "config": { + "node_id": "node:4" + } + } + } + } + ``` + + Then, in python file, you can set up an emulator, + + ```python + conf_path = "spu_128.json" # path of json file defined above + mode = emulation.Mode.MULTIPROCESS + emulator = emulation.Emulator(conf_path, mode, bandwidth=300, latency=20) + emulator.up() + ``` + +3. Why the program I write runs correctly in plaintext, but behaves differently under MPC? + + **Ans**: it depends. + - Huge error: you can check whether **overflow/underflow** is happened(often occurs when linear algebra ops such as `jax.numpy.linalg.*` are used), + whether incidentally use **floating-point random generator** in SPU. + - Mild error: you can switch to **larger ring** and more **fxp**. diff --git a/sml/sml_develop.svg b/sml/sml_develop.svg new file mode 100644 index 00000000..cbb79088 --- /dev/null +++ b/sml/sml_develop.svg @@ -0,0 +1 @@ +Start upWhetherduplicateYJust use itNImplement withJAX and testin plaintextSimulationTestNYAccuracyPassEmulationTestNYEfficiencyPassImprove numericalprecisionOptimize based oncost modelPull Requests \ No newline at end of file diff --git a/sml/support_lists.md b/sml/support_lists.md new file mode 100644 index 00000000..39b448ed --- /dev/null +++ b/sml/support_lists.md @@ -0,0 +1,25 @@ +# Algorithm Support lists + +The table below shows the capabilities currently available in SML. +In general, the following features are rarely(partly) supported in SML: + +- **Early stop** for training or iterating algorithm: We do not want to reveal any intermediate information. +- Manual set of **random seed**: SPU can't handle randomness of float properly, so if random value(matrix) is needed, +user should pass it as a parameter(such as `rsvd`, `NMF`) +- **Data inspection** like counting the number of label, re-transforming the data or label won't be done. +(So we may assume a "fixed" format for input or just tell the number of classes as a parameter) +- **single-sample SGD** not implemented for the latency consideration, MiniBatch-SGD(which we just call it `sgd` in sml) will replace it. +- Jax's Ops like `eigh`, `svd` can't run in SPU directly: `svd` implemented now is expensive and can't handle matrix that is not column full-rank matrix. + +| Algorithm | category | Supported features | Notes | +|:--------------------:|:-------------:|:---------------------------------------------------------------------------------------:|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| KMEANS | cluster | init=`random` , algorithm=`lloyd` only | only run algo once for efficiency | +| PCA | decomposition | 1. `power_iteration` method(not used in scikit-learn) supported
2. `rsvd` method | 1. if method=`power_iteration`, then cov matrix will be computed first
2.`rsvd` is very unstable under fixedpoint setting even in `FM128`, so only small data is supported. | +| NMF | decomposition | init=`random`, solver=`mu`, beta_loss=`frobenius` only | | +| Logistic | linear model | 1. `sgd` solver only
2.only L2 regularization supported | 1. `sigmoid` will be evaluated approximately | +| Perceptron | linear model | 1. all regularization methods
2.patience-based early stop | 1. this early stop will not cut down the training time, it just forces the update of parameters stop | +| Ridge | linear model | 1. `svd` and `cholesky` solver only | | +| SgdClassifier | linear model | 1. linear regression and logistic regression only
L2 regularization supported only | 1. `sigmoid` will be evaluated approximately | +| Gaussian Naive Bayes | naive_bayes | 1. not support manual set of priors | | +| KNN | neighbors | 1.`brute` algorithm only
`uniform` and `distance` weights supported | 1. KD-tree or Ball-tree can't improve the efficiency in MPC setting | +| roc_auc_score | metric | 1.binary only | |