diff --git a/CHANGELOG.md b/CHANGELOG.md index 52a5d94b..e5cbeb99 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ - [Feature] Support half type - [Feature] Add Psi Progress - [Feature] Add SineOp/CosineOp support +- [Feature] Add complex support ## 20230705 diff --git a/docs/reference/xla_status.md b/docs/reference/xla_status.md index e1cf948d..605835c7 100644 --- a/docs/reference/xla_status.md +++ b/docs/reference/xla_status.md @@ -7,206 +7,192 @@ List of XLA(mhlo-mlir) Ops that SPU supports: General limitation with SPU: * Dynamic shape is not supported - * Complex number is not supported - * SPU only supports fixed-point numbers, so no-finite is not supported + * SPU uses fixed-point numbers to simulate floating point, so nonfinite numbers are not supported ### XLA nullary ops -| Op Name | supported(fully/partial/no) | notes | -| :------------: | :-------------------------: | ----------- | -| `constant` | fully | Always yields a public value -| `iota` | fully | Always yields a public value -| `dynamic_iota` | no | -| `create_token` | no | +| Op Name | supported(yes/partial/no) | notes | +| :------------: | :-----------------------: | ----------- | +| `constant` | yes | Always yields a public value +| `iota` | yes | Always yields a public value +| `dynamic_iota` | no | +| `create_token` | no | -Count: Total = 4, fully supported = 2 ### XLA unary element-wise ops -| Op Name | supported(fully/partial/no) | notes | -| :------------: | :-------------------------: | ----------- | -| `abs` | fully | -| `cbrt` | no | -| `ceil` | fully | -| `convert` | fully | -| `count_leading_zeros`| no | -| `cosine` | fully | -| `exponential` | fully | -| `exponential_minus_one`| fully | -| `floor` | fully | -| `imag` | no | -| `is_finite` | no | -| `log` | fully | -| `log_plus_one` | fully | -| `logistic` | fully | -| `not` | fully | -| `negate` | fully | -| `popcnt` | not | -| `real` | not | -| `round_nearest_afz`| not | -| `rsqrt` | fully | -| `sign` | partial | -| `sine` | fully | -| `sqrt` | fully | -| `tanh` | fully | - -Count: Total = 24, fully supported = 16, partial = 1 +| Op Name | supported(yes/partial/no) | notes | +| :------------: | :-----------------------: | ----------- | +| `abs` | yes | +| `cbrt` | no | +| `ceil` | yes | +| `convert` | yes | +| `count_leading_zeros`| no | +| `cosine` | yes | +| `exponential` | yes | +| `exponential_minus_one`| yes | +| `floor` | yes | +| `imag` | yes | +| `is_finite` | no | +| `log` | yes | +| `log_plus_one` | yes | +| `logistic` | yes | +| `not` | yes | +| `negate` | yes | +| `popcnt` | not | +| `real` | yes | +| `round_nearest_afz`| not | +| `rsqrt` | yes | +| `sign` | partial | +| `sine` | yes | +| `sqrt` | yes | +| `tanh` | yes | + ### XLA binary element-wise ops -| Op Name | supported(fully/partial/no) | notes | -| :------------: | :-------------------------: | ----------- | -| `add` | fully | -| `atan2` | no | -| `complex` | no | -| `compare` | fully | -| `divide` | fully | -| `maximum` | fully | -| `minimum` | fully | -| `multiply` | fully | -| `power` | fully | -| `remainder` | fully | -| `shift_left` | partial | rhs must be a public -| `shift_right_arithmetic` | partial | rhs must be a public -| `shift_right_logical` | partial | rhs must be a public -| `subtract` | fully | - -Count: Total = 14, fully supported = 9, partial = 3 +| Op Name | supported(yes/partial/no) | notes | +| :------------: | :-----------------------: | ----------- | +| `add` | yes | +| `atan2` | no | +| `complex` | yes | +| `compare` | yes | +| `divide` | yes | +| `maximum` | yes | +| `minimum` | yes | +| `multiply` | yes | +| `power` | yes | +| `remainder` | yes | +| `shift_left` | partial | rhs must be a public +| `shift_right_arithmetic` | partial | rhs must be a public +| `shift_right_logical` | partial | rhs must be a public +| `subtract` | yes | + ### XLA binary logical element-wise ops -| Op Name | supported(fully/partial/no) | notes | -| :------------: | :-------------------------: | ----------- | -| `and` | fully | -| `or` | fully | -| `xor` | fully | +| Op Name | supported(yes/partial/no) | notes | +| :------------: | :-----------------------: | ----------- | +| `and` | yes | +| `or` | yes | +| `xor` | yes | -Count: Total = 3, fully supported = 3 ### XLA communication ops -| Op Name | supported(fully/partial/no) | notes | -| :------------: | :-------------------------: | ----------- | -| `infeed` | no | -| `outfeed` | no | -| `send` | no | -| `recv` | no | +| Op Name | supported(yes/partial/no) | notes | +| :------------: | :-----------------------: | ----------- | +| `infeed` | no | +| `outfeed` | no | +| `send` | no | +| `recv` | no | -Count: Total = 4, fully supported = 0 ### XLA parallelism related ops -| Op Name | supported(fully/partial/no) | notes | -| :------------: | :-------------------------: | ----------- | -| `replica_id` | no | +| Op Name | supported(yes/partial/no) | notes | +| :------------: | :-----------------------: | ----------- | +| `replica_id` | no | -Count: Total = 1, fully supported = 0 ### XLA control flow ops -| Op Name | supported(fully/partial/no) | notes | -| :------------: | :-------------------------: | ----------- | -| `after_all` | no | -| `if` | fully | -| `case` | no | -| `while` | partial | condition region must return a public scalar -| `all_gather` | no | -| `all_reduce` | no | -| `reduce_scatter` | no | -| `all_to_all` | no | -| `reduce` | fully | inherits limitations from reduce function - -Count: Total = 9, fully supported = 2, partial = 1 +| Op Name | supported(yes/partial/no) | notes | +| :------------: | :-----------------------: | ----------- | +| `after_all` | no | +| `if` | yes | +| `case` | no | +| `while` | partial | condition region must return a public scalar +| `all_gather` | no | +| `all_reduce` | no | +| `reduce_scatter` | no | +| `all_to_all` | no | +| `reduce` | yes | inherits limitations from reduce function + ### XLA tuple ops -| Op Name | supported(fully/partial/no) | notes | -| :------------: | :-------------------------: | ----------- | -| `get_tuple_element` | fully | -| `tuple` | fully | +| Op Name | supported(yes/partial/no) | notes | +| :------------: | :-----------------------: | ----------- | +| `get_tuple_element` | yes | +| `tuple` | yes | -Count: Total = 2, fully supported = 2 ### XLA other ops -| Op Name | supported(fully/partial/no) | notes | +| Op Name | supported(yes/partial/no) | notes | | :------------: | :-------------------------: | ----------- | -| `slice` | fully | -| `dynamic-slice`| fully | -| `dynamic-update-slice`| fully | -| `batch_norm_grad`| fully | Rely on XLA's batchnorm_expander pass -| `batch_norm_inference`| fully | Rely on XLA's batchnorm_expander pass -| `batch_norm_training` | fully | Rely on XLA's batchnorm_expander pass -| `bitcast_convert` | partial | Only supports convert to type of same size -| `broadcast` | fully | -| `broadcast_in_dim` | fully | -| `dynamic_broadcast_in_dim` | no | -| `cholesky` | fully | Rely on CholeskyExpander pass -| `clamp` | fully | -| `concatenate` | fully | -| `collective_permute` | no | -| `convolution` | fully | -| `copy` | no | -| `cross-replica-sum` | no | -| `custom_call` | no | -| `dot` | fully | -| `dot_general` | fully | -| `einsum` | fully | -| `unary_einsum` | fully | -| `fft` | no | -| `gather` | fully | -| `get_dimension_size` | no | -| `map` | fully | Rely on XLA's MapInliner pass -| `reshape` | fully | -| `dynamic_reshape` | no | -| `scatter` | no | -| `select` | fully | -| `select_and_scatter` | fully | -| `set_dimension_size` | no | -| `sort` | fully | -| `reverse` | fully | -| `pad` | fully | -| `trace` | no | -| `transpose` | fully | -| `triangular_solve` | fully | Rely on XLA's TriangularSolverExpander pass -| `reduce_window`| fully | -| `return` | fully | -| `torch_index_select` | no | -| `optimization_barrier` | no | - -Count: Total = 42, fully supported = 28, partial = 1 +| `slice` | yes | +| `dynamic-slice`| yes | +| `dynamic-update-slice`| yes | +| `batch_norm_grad`| yes | Rely on XLA's batchnorm_expander pass +| `batch_norm_inference`| yes | Rely on XLA's batchnorm_expander pass +| `batch_norm_training` | yes | Rely on XLA's batchnorm_expander pass +| `bitcast_convert` | partial | Only supports convert to type of same size +| `broadcast` | yes | +| `broadcast_in_dim` | yes | +| `dynamic_broadcast_in_dim` | no | +| `cholesky` | yes | Rely on CholeskyExpander pass +| `clamp` | yes | +| `concatenate` | yes | +| `collective_permute` | no | +| `convolution` | yes | +| `copy` | no | +| `cross-replica-sum` | no | +| `custom_call` | no | +| `dot` | yes | +| `dot_general` | yes | +| `einsum` | yes | +| `unary_einsum` | yes | +| `fft` | no | +| `gather` | yes | +| `get_dimension_size` | no | +| `map` | yes | Rely on XLA's MapInliner pass +| `reshape` | yes | +| `dynamic_reshape` | no | +| `scatter` | no | +| `select` | yes | +| `select_and_scatter` | yes | +| `set_dimension_size` | no | +| `sort` | yes | +| `reverse` | yes | +| `pad` | yes | +| `trace` | no | +| `transpose` | yes | +| `triangular_solve` | yes | Rely on XLA's TriangularSolverExpander pass +| `reduce_window`| yes | +| `return` | yes | +| `torch_index_select` | no | +| `optimization_barrier` | no | + ### XLA RNG ops -| Op Name | supported(fully/partial/no) | notes | -| :------------: | :-------------------------: | ----------- | -| `rng_uniform` | partial | Bound [a, b) must all be public scalar, result is also a public tensor -| `rng_normal` | no | -| `rng_bit_generator` | no | +| Op Name | supported(yes/partial/no) | notes | +| :------------: | :-----------------------: | ----------- | +| `rng_uniform` | yes | Bound [a, b) must all be public scalar, result is also a public tensor +| `rng_normal` | no | +| `rng_bit_generator` | no | -Count: Total = 3, fully supported = 0, partial = 1 ### XLA quantize op -| Op Name | supported(fully/partial/no) | notes | -| :------------: | :-------------------------: | ----------- | -| `dequantize` | no | +| Op Name | supported(yes/partial/no) | notes | +| :------------: | :-----------------------: | ----------- | +| `dequantize` | no | -Count: Total = 1, fully supported = 0, partial = 0 ### XLA miscellaneous ops -| Op Name | supported(fully/partial/no) | notes | -| :------------: | :-------------------------: | ----------- | -| `fusion` | no | -| `bitcast` | no | Internal op to XLA/GPU -| `reduce_precision` | no | -| `real_dynamic_slice` | no | -| `dynamic_pad` | no | -| `dynamic_gather` | no | -| `dynamic_conv` | no | -| `print` | no | -| `compute_reshape_shape` | no | -| `cstr_reshapable` | no | - -Count: Total = 10, fully supported = 0, partial = 0 +| Op Name | supported(yes/partial/no) | notes | +| :------------: | :-----------------------: | ----------- | +| `fusion` | no | +| `bitcast` | no | Internal op to XLA/GPU +| `reduce_precision` | no | +| `real_dynamic_slice` | no | +| `dynamic_pad` | no | +| `dynamic_gather` | no | +| `dynamic_conv` | no | +| `print` | no | +| `compute_reshape_shape` | no | +| `cstr_reshapable` | no | diff --git a/examples/python/ml/ml_test.py b/examples/python/ml/ml_test.py index f9ae8a10..3c2b8d18 100644 --- a/examples/python/ml/ml_test.py +++ b/examples/python/ml/ml_test.py @@ -244,9 +244,10 @@ def suite(): suite.addTest(UnitTests('test_ss_xgb')) suite.addTest(UnitTests('test_stax_mnist_classifier')) suite.addTest(UnitTests('test_stax_nn')) + suite.addTest(UnitTests('test_save_and_load_model')) + # should put JAX tests above suite.addTest(UnitTests('test_tf_experiment')) suite.addTest(UnitTests('test_torch_experiment')) - suite.addTest(UnitTests('test_save_and_load_model')) return suite diff --git a/libspu/compiler/front_end/fe.cc b/libspu/compiler/front_end/fe.cc index 3681db40..60d5ebc0 100644 --- a/libspu/compiler/front_end/fe.cc +++ b/libspu/compiler/front_end/fe.cc @@ -96,6 +96,7 @@ void FE::buildFrontEndPipeline(mlir::PassManager *pm, const std::string &args) { pm->addPass(mlir::mhlo::createExpandHloTuplesPass()); auto &optPM = pm->nest(); + optPM.addPass(mlir::mhlo::createLowerComplexPass()); optPM.addPass(mlir::mhlo::createLegalizeEinsumToDotGeneralPass()); optPM.addPass(mlir::mhlo::createLegalizeGeneralDotPass()); optPM.addPass(mlir::mhlo::createSinkConstantsToControlFlowPass()); diff --git a/libspu/compiler/passes/hlo_legalize_to_pphlo.cc b/libspu/compiler/passes/hlo_legalize_to_pphlo.cc index 8ecfd621..a50bd843 100644 --- a/libspu/compiler/passes/hlo_legalize_to_pphlo.cc +++ b/libspu/compiler/passes/hlo_legalize_to_pphlo.cc @@ -75,7 +75,8 @@ class HloToPPHloTypeConverter : public TypeConverter { Type oriElmTy = type.getElementType(); Type newElmTy; if (oriElmTy.isa<::mlir::FloatType>() || - oriElmTy.isa<::mlir::IntegerType>()) { + oriElmTy.isa<::mlir::IntegerType>() || + oriElmTy.isa<::mlir::ComplexType>()) { newElmTy = ::mlir::pphlo::UnsetType::get(oriElmTy); } else { newElmTy = oriElmTy; @@ -1531,6 +1532,9 @@ struct HloLegalizeToPPHlo .insert, + HloToPPHloOpConverter, + HloToPPHloOpConverter, HloToPPHloOpConverter, HloToPPHloOpConverter, HloToPPHloOpConverter, diff --git a/libspu/compiler/passes/map_stablehlo_to_pphlo_op.h b/libspu/compiler/passes/map_stablehlo_to_pphlo_op.h index 86650227..4aa3f8a5 100644 --- a/libspu/compiler/passes/map_stablehlo_to_pphlo_op.h +++ b/libspu/compiler/passes/map_stablehlo_to_pphlo_op.h @@ -94,6 +94,9 @@ MAP_HLO_TO_PPHLO(SubtractOp) MAP_HLO_TO_PPHLO(TanhOp) MAP_HLO_TO_PPHLO(TransposeOp) MAP_HLO_TO_PPHLO(XorOp) +MAP_HLO_TO_PPHLO(RealOp) +MAP_HLO_TO_PPHLO(ImagOp) +MAP_HLO_TO_PPHLO(ComplexOp) MAP_HLO_TO_PPHLO_DIFF_NAME(BroadcastInDimOp, BroadcastOp) diff --git a/libspu/compiler/tests/hlo_to_pphlo_complex.mlir b/libspu/compiler/tests/hlo_to_pphlo_complex.mlir new file mode 100644 index 00000000..a02004ee --- /dev/null +++ b/libspu/compiler/tests/hlo_to_pphlo_complex.mlir @@ -0,0 +1,11 @@ +// RUN: mlir-pphlo-opt --hlo-legalize-to-pphlo=input_vis_list=VIS_PUBLIC --split-input-file %s | FileCheck %s + +func.func @main(%arg0: tensor<3xcomplex>) -> tensor<3xcomplex> { + // CHECK: pphlo.real + %0 = stablehlo.real %arg0 : (tensor<3xcomplex>) -> tensor<3xf32> + // CHECK: pphlo.imag + %1 = stablehlo.imag %arg0 : (tensor<3xcomplex>) -> tensor<3xf32> + // CHECK: pphlo.complex + %2 = stablehlo.complex %0, %1 : tensor<3xcomplex> + return %2 : tensor<3xcomplex> + } diff --git a/libspu/core/type.cc b/libspu/core/type.cc index 25c64dc8..ab19ce22 100644 --- a/libspu/core/type.cc +++ b/libspu/core/type.cc @@ -64,6 +64,8 @@ Type F64 = makePtType(PT_F64); Type I128 = makePtType(PT_I128); Type U128 = makePtType(PT_U128); Type BOOL = makePtType(PT_BOOL); +Type CF32 = makePtType(PT_CF32); +Type CF64 = makePtType(PT_CF64); bool isFloatTy(const Type& type) { if (!type.isa()) { diff --git a/libspu/core/type.h b/libspu/core/type.h index faa3b0d5..c9076943 100644 --- a/libspu/core/type.h +++ b/libspu/core/type.h @@ -330,6 +330,8 @@ extern Type F32; extern Type F64; extern Type I128; extern Type U128; +extern Type CF32; +extern Type CF64; class RingTy : public TypeImpl { using Base = TypeImpl; diff --git a/libspu/core/type_util.cc b/libspu/core/type_util.cc index 4ffd4087..33113868 100644 --- a/libspu/core/type_util.cc +++ b/libspu/core/type_util.cc @@ -106,6 +106,7 @@ size_t SizeOf(PtType ptt) { case PT_INVALID: return 0; FOREACH_PT_TYPES(CASE); + FOREACH_COMPLEX_PT_TYPES(CASE); default: SPU_THROW("unknown size of {}", ptt); } diff --git a/libspu/core/type_util.h b/libspu/core/type_util.h index 40474f3b..7a3efae3 100644 --- a/libspu/core/type_util.h +++ b/libspu/core/type_util.h @@ -16,6 +16,7 @@ #pragma once +#include #include #include "fmt/format.h" @@ -88,6 +89,10 @@ std::ostream& operator<<(std::ostream& os, const DataType& dtype); FN(PT_U128, uint128_t, U128) \ FN(PT_BOOL, bool, I1) +#define FOREACH_COMPLEX_PT_TYPES(FN) \ + FN(PT_CF32, std::complex, CF32) \ + FN(PT_CF64, std::complex, CF64) + #define FOREACH_PT_TYPES(FN) \ FOREACH_INT_PT_TYPES(FN) \ FOREACH_FLOAT_PT_TYPES(FN) diff --git a/libspu/core/value.cc b/libspu/core/value.cc index bd1fd1c2..75306b4e 100644 --- a/libspu/core/value.cc +++ b/libspu/core/value.cc @@ -41,6 +41,11 @@ Visibility getVisibilityFromType(const Type& ty) { Value::Value(NdArrayRef data, DataType dtype) : data_(std::move(data)), dtype_(dtype) {} +Value::Value(NdArrayRef real, NdArrayRef imag, DataType dtype) + : data_(std::move(real)), imag_(std::move(imag)), dtype_(dtype) { + SPU_ENFORCE(data_.eltype() == imag_->eltype()); +} + Visibility Value::vtype() const { return getVisibilityFromType(storage_type()); }; @@ -66,7 +71,7 @@ size_t Value::chunksCount(size_t max_chunk_size) const { ValueProto Value::toProto(size_t max_chunk_size) const { SPU_ENFORCE(max_chunk_size > 0); - SPU_ENFORCE(dtype_ != DT_INVALID && vtype() != VIS_INVALID); + SPU_ENFORCE(dtype_ != DT_INVALID && vtype() != VIS_INVALID, "{}", *this); ValueProto ret; @@ -74,7 +79,7 @@ ValueProto Value::toProto(size_t max_chunk_size) const { if (size == 0) { return; } - ret.chunks.reserve(num_chunks); + ret.chunks.reserve(ret.chunks.size() + num_chunks); for (size_t i = 0; i < num_chunks; i++) { size_t chunk_size = std::min(max_chunk_size, size - i * max_chunk_size); @@ -92,15 +97,21 @@ ValueProto Value::toProto(size_t max_chunk_size) const { const size_t num_chunks = chunksCount(max_chunk_size); - if (data_.isCompact()) { - build_chunk(data_.data(), numel() * data_.elsize(), num_chunks); - } else { - // Make a compact clone - auto copy = data_.clone(); - SPU_ENFORCE(copy.isCompact(), "Must be a compact copy."); - build_chunk(copy.data(), copy.buf()->size(), num_chunks); - } + auto array_to_chunks = [&](const NdArrayRef& a) { + if (a.isCompact()) { + build_chunk(a.data(), numel() * a.elsize(), num_chunks); + } else { + // Make a compact clone + auto copy = a.clone(); + SPU_ENFORCE(copy.isCompact(), "Must be a compact copy."); + build_chunk(copy.data(), copy.buf()->size(), num_chunks); + } + }; + array_to_chunks(data_); + if (imag_) { + array_to_chunks(*imag_); + } ret.meta.CopyFrom(toMetaProto()); return ret; @@ -111,6 +122,7 @@ ValueMetaProto Value::toMetaProto() const { ValueMetaProto proto; proto.set_data_type(dtype_); + proto.set_is_complex(isComplex()); proto.set_visibility(vtype()); for (const auto& d : shape()) { proto.mutable_shape()->add_dims(d); @@ -121,6 +133,24 @@ ValueMetaProto Value::toMetaProto() const { Value Value::fromProto(const ValueProto& value) { const auto& meta = value.meta; + if (meta.is_complex()) { + // real + ValueMetaProto partial = value.meta; + partial.set_is_complex(false); + ValueProto partial_proto; + partial_proto.meta = partial; + auto n = value.chunks.size() / 2; + std::copy_n(value.chunks.begin(), n, + std::back_inserter(partial_proto.chunks)); + auto rv = fromProto(partial_proto); + + partial_proto.chunks.clear(); + std::copy_n(value.chunks.begin() + n, n, + std::back_inserter(partial_proto.chunks)); + auto iv = fromProto(partial_proto); + return Value(rv.data(), iv.data(), rv.dtype()); + } + const auto eltype = Type::fromString(meta.storage_type()); SPU_ENFORCE(meta.data_type() != DT_INVALID, "invalid data type={}", diff --git a/libspu/core/value.h b/libspu/core/value.h index 5b9cbeff..162ae32e 100644 --- a/libspu/core/value.h +++ b/libspu/core/value.h @@ -38,11 +38,13 @@ struct ValueProto { class Value final { NdArrayRef data_; + std::optional imag_; DataType dtype_ = DT_INVALID; public: Value() = default; explicit Value(NdArrayRef data, DataType dtype); + explicit Value(NdArrayRef real, NdArrayRef imag, DataType dtype); /// Forward ndarray methods. inline int64_t numel() const { return data_.numel(); } @@ -58,6 +60,9 @@ class Value final { const NdArrayRef& data() const { return data_; } NdArrayRef& data() { return data_; } + const std::optional& imag() const { return imag_; } + std::optional& imag() { return imag_; } + // Get vtype, is readonly and decided by the underline secure compute engine. Visibility vtype() const; bool isPublic() const { return vtype() == VIS_PUBLIC; } @@ -67,6 +72,7 @@ class Value final { DataType dtype() const { return dtype_; } bool isInt() const { return isInteger(dtype()); } bool isFxp() const { return isFixedPoint(dtype()); } + bool isComplex() const { return imag_.has_value(); } // Set dtype. // diff --git a/libspu/device/io.cc b/libspu/device/io.cc index c38772f6..68dcc601 100644 --- a/libspu/device/io.cc +++ b/libspu/device/io.cc @@ -61,6 +61,31 @@ std::vector IoClient::makeShares(const PtBufferView &bv, return result; } + if (bv.pt_type == PT_CF32 || bv.pt_type == PT_CF64) { + auto s_type = bv.pt_type == PT_CF32 ? PT_F32 : PT_F64; + auto offset = bv.pt_type == PT_CF32 ? sizeof(float) : sizeof(double); + // SPDLOG_INFO("bv.strides = {}", bv.strides); + + Strides ds = bv.strides; + for (auto &s : ds) { + s *= 2; + } + + PtBufferView real_view(bv.ptr, s_type, bv.shape, ds); + PtBufferView imag_view((std::byte *)bv.ptr + offset, s_type, bv.shape, ds); + + auto r_shares = makeShares(real_view, vtype, owner_rank); + auto i_shares = makeShares(imag_view, vtype, owner_rank); + + std::vector result; + result.reserve(world_size_); + for (size_t idx = 0; idx < world_size_; ++idx) { + result.emplace_back(r_shares[idx].data(), i_shares[idx].data(), + r_shares[idx].dtype()); + } + return result; + } + // encode to ring. DataType dtype; NdArrayRef encoded = @@ -78,11 +103,50 @@ std::vector IoClient::makeShares(const PtBufferView &bv, return result; } +template +NdArrayRef combineComplex(const NdArrayRef &real, const NdArrayRef &imag, + const Type &complex_type) { + NdArrayRef ret(complex_type, real.shape()); + NdArrayView> ret_v(ret); + NdArrayView rv(real); + NdArrayView iv(imag); + for (int64_t idx = 0; idx < real.numel(); ++idx) { + ret_v[idx].real(rv[idx]); + ret_v[idx].imag(iv[idx]); + } + return ret; +} + NdArrayRef IoClient::combineShares(absl::Span values) { SPU_ENFORCE(values.size() == world_size_, "wrong number of shares, got={}, expect={}", values.size(), world_size_); + if (values.front().isComplex()) { + NdArrayRef real; + NdArrayRef imag; + { + std::vector reals(values.size()); + for (size_t idx = 0; idx < values.size(); ++idx) { + reals[idx] = Value(values[idx].data(), values[idx].dtype()); + } + real = combineShares(reals); + } + { + std::vector imags(values.size()); + for (size_t idx = 0; idx < values.size(); ++idx) { + imags[idx] = Value(*values[idx].imag(), values[idx].dtype()); + } + imag = combineShares(imags); + } + if (values.front().dtype() == DT_F32) { + return combineComplex(real, imag, CF32); + } else { + SPU_ENFORCE(values.front().dtype() == DT_F64); + return combineComplex(real, imag, CF64); + } + } + const size_t fxp_bits = config_.fxp_fraction_bits(); SPU_ENFORCE(fxp_bits != 0, "fxp should never be zero, please check default"); diff --git a/libspu/device/pphlo/pphlo_executor.cc b/libspu/device/pphlo/pphlo_executor.cc index a3520afe..c20fea55 100644 --- a/libspu/device/pphlo/pphlo_executor.cc +++ b/libspu/device/pphlo/pphlo_executor.cc @@ -1110,6 +1110,25 @@ void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope, removeValue(sscope, op.getOperand(), opts); } +void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope, + mlir::pphlo::RealOp &op, const ExecutionOptions &opts) { + auto v = lookupValue(sscope, op.getOperand(), opts); + addValue(sscope, op.getResult(), kernel::hlo::Real(sctx, v), opts); +} + +void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope, + mlir::pphlo::ImagOp &op, const ExecutionOptions &opts) { + auto v = lookupValue(sscope, op.getOperand(), opts); + addValue(sscope, op.getResult(), kernel::hlo::Imag(sctx, v), opts); +} + +void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope, + mlir::pphlo::ComplexOp &op, const ExecutionOptions &opts) { + auto r = lookupValue(sscope, op.getLhs(), opts); + auto i = lookupValue(sscope, op.getRhs(), opts); + addValue(sscope, op.getResult(), kernel::hlo::Complex(sctx, r, i), opts); +} + #define DEFINE_UNIMPLEMENTED_OP(OpName) \ void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope, \ mlir::pphlo::OpName &, const ExecutionOptions &opts) { \ diff --git a/libspu/device/pphlo/pphlo_verifier.h b/libspu/device/pphlo/pphlo_verifier.h index c0ad262c..9b007956 100644 --- a/libspu/device/pphlo/pphlo_verifier.h +++ b/libspu/device/pphlo/pphlo_verifier.h @@ -146,6 +146,9 @@ class PPHloVerifier { NO_VERIFY_DEFN(EpsilonOp) NO_VERIFY_DEFN(CustomCallOp) NO_VERIFY_DEFN(FreeOp) + NO_VERIFY_DEFN(RealOp) + NO_VERIFY_DEFN(ImagOp) + NO_VERIFY_DEFN(ComplexOp) #undef NO_VERIFY_DEFN }; diff --git a/libspu/dialect/pphlo_ops.td b/libspu/dialect/pphlo_ops.td index eb3985be..ae64327f 100644 --- a/libspu/dialect/pphlo_ops.td +++ b/libspu/dialect/pphlo_ops.td @@ -81,10 +81,10 @@ def PPHLO_IotaOp : PPHLO_Op<"iota", [Pure]> { // pphlo unary elementwise op definitions. //===----------------------------------------------------------------------===// class PPHLO_UnaryElementwiseOp traits, - Type TensorType> + Type OperandType, Type ResultType = OperandType> : PPHLO_Op { - let arguments = (ins TensorType : $operand); - let results = (outs TensorType); + let arguments = (ins OperandType : $operand); + let results = (outs ResultType); } def PPHLO_ConvertOp @@ -1138,14 +1138,6 @@ def PPHLO_CustomCallOp: PPHLO_Op<"custom_call", See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#custom_call - - Example: - ```mlir - %results = stablehlo.custom_call @foo(%input0) { - backend_config = "bar", - called_computations = [@foo] - } : (tensor) -> tensor - ``` }]; let arguments = (ins @@ -1161,5 +1153,47 @@ def PPHLO_CustomCallOp: PPHLO_Op<"custom_call", }]; } +def PPHLO_RealOp: PPHLO_UnaryElementwiseOp<"real", + [Pure], PPHLO_ComplexTensor, PPHLO_Tensor> { + let summary = "Real operation"; + let description = [{ + Extracts the real part, element-wise, from the `operand` and produces a + `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#real + }]; +} + +def PPHLO_ImagOp: PPHLO_UnaryElementwiseOp<"imag", + [Pure], PPHLO_ComplexTensor, PPHLO_Tensor> { + let summary = "Imag operation"; + let description = [{ + Extracts the imaginary part, element-wise, from the `operand` and produces a + `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#imag + }]; +} + +def PPHLO_ComplexOp: PPHLO_BinaryElementwiseOp<"complex", [Pure, + SameOperandsElementType /*complex_c1*/, + SameOperandsAndResultShape /*complex_c2*/]> { + let summary = "Complex operation"; + let description = [{ + Performs element-wise conversion to a complex value from a pair of real and + imaginary values, `lhs` and `rhs`, and produces a `result` tensor. + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#complex + }]; + let arguments = (ins + PPHLO_FpTensor:$lhs /*complex_i1*/, + PPHLO_FpTensor:$rhs /*complex_i2*/ + ); + let results = (outs + PPHLO_ComplexTensor:$result + ); +} #endif // SPU_DIALECT_PPHLO_OPS diff --git a/libspu/dialect/pphlo_types.td b/libspu/dialect/pphlo_types.td index 247ad392..a644cb0d 100644 --- a/libspu/dialect/pphlo_types.td +++ b/libspu/dialect/pphlo_types.td @@ -70,4 +70,12 @@ def PPHLO_IntTensor : StaticShapeTensorOf<[PPHLO_PublicIntType, PPHLO_SecretIntT def PPHLO_FpTensor : StaticShapeTensorOf<[PPHLO_PublicFpType, PPHLO_SecretFpType]>; def PPHLO_ScalarIntTensor : 0DTensorOf<[PPHLO_PublicIntType, PPHLO_SecretIntType]>; +def PPHLO_PublicComplexType : Type().getBase().isa<::mlir::ComplexType>()">]>, "public complex type", "::pphlo::PubComplexType">; + +def PPHLO_SecretComplexType : Type().getBase().isa<::mlir::ComplexType>()">]>, "secret complex type", "::pphlo::SecComplexType">; + +def PPHLO_ComplexTensor : StaticShapeTensorOf<[PPHLO_PublicComplexType, PPHLO_SecretComplexType]>; + #endif // SPU_DIALECT_PPHLO_TYPES diff --git a/libspu/kernel/hlo/basic_binary.cc b/libspu/kernel/hlo/basic_binary.cc index bf8211fc..d8f06ab0 100644 --- a/libspu/kernel/hlo/basic_binary.cc +++ b/libspu/kernel/hlo/basic_binary.cc @@ -72,4 +72,14 @@ spu::Value Dot(SPUContext *ctx, const spu::Value &lhs, const spu::Value &rhs) { return hal::matmul(ctx, lhs, rhs); } +spu::Value Complex(SPUContext *ctx, const spu::Value &lhs, + const spu::Value &rhs) { + SPU_ENFORCE(lhs.dtype() == rhs.dtype()); + SPU_ENFORCE(lhs.vtype() == rhs.vtype()); + SPU_ENFORCE(lhs.shape() == rhs.shape()); + SPU_ENFORCE(!lhs.imag().has_value() && !rhs.imag().has_value()); + + return Value(lhs.data(), rhs.data(), lhs.dtype()); +} + } // namespace spu::kernel::hlo diff --git a/libspu/kernel/hlo/basic_binary.h b/libspu/kernel/hlo/basic_binary.h index 7386ccb8..32cbe74d 100644 --- a/libspu/kernel/hlo/basic_binary.h +++ b/libspu/kernel/hlo/basic_binary.h @@ -44,6 +44,7 @@ SIMPLE_BINARY_KERNEL_DECL(Xor) SIMPLE_BINARY_KERNEL_DECL(Div) SIMPLE_BINARY_KERNEL_DECL(Remainder) SIMPLE_BINARY_KERNEL_DECL(Dot) +SIMPLE_BINARY_KERNEL_DECL(Complex) #undef SIMPLE_BINARY_KERNEL_DECL diff --git a/libspu/kernel/hlo/basic_unary.cc b/libspu/kernel/hlo/basic_unary.cc index b14b13de..5a30234e 100644 --- a/libspu/kernel/hlo/basic_unary.cc +++ b/libspu/kernel/hlo/basic_unary.cc @@ -87,4 +87,14 @@ spu::Value Round_AFZ(SPUContext *ctx, const spu::Value &in) { return hal::dtype_cast(ctx, hal::dtype_cast(ctx, round, DT_I64), in.dtype()); } +spu::Value Real(SPUContext *ctx, const spu::Value &in) { + SPU_ENFORCE(in.imag().has_value(), "In must be a complex value"); + return Value(in.data(), in.dtype()); +} + +spu::Value Imag(SPUContext *ctx, const spu::Value &in) { + SPU_ENFORCE(in.imag().has_value(), "In must be a complex value"); + return Value(*in.imag(), in.dtype()); // NOLINT +} + } // namespace spu::kernel::hlo diff --git a/libspu/kernel/hlo/basic_unary.h b/libspu/kernel/hlo/basic_unary.h index 8048cb95..3be84357 100644 --- a/libspu/kernel/hlo/basic_unary.h +++ b/libspu/kernel/hlo/basic_unary.h @@ -39,6 +39,8 @@ SIMPLE_UNARY_KERNEL_DECL(Rsqrt) SIMPLE_UNARY_KERNEL_DECL(Sqrt) SIMPLE_UNARY_KERNEL_DECL(Sign) SIMPLE_UNARY_KERNEL_DECL(Round_AFZ) +SIMPLE_UNARY_KERNEL_DECL(Real) +SIMPLE_UNARY_KERNEL_DECL(Imag) #undef SIMPLE_UNARY_KERNEL_DECL diff --git a/libspu/kernel/hlo/select_and_scatter_test.cc b/libspu/kernel/hlo/select_and_scatter_test.cc index 6bd3d8d1..182e8ecd 100644 --- a/libspu/kernel/hlo/select_and_scatter_test.cc +++ b/libspu/kernel/hlo/select_and_scatter_test.cc @@ -63,7 +63,7 @@ TEST_P(SelectAndScatterTest, ParamTest) { EXPECT_TRUE(xt::allclose(expected, ret_hat, 0.01, 0.001)); } -INSTANTIATE_TEST_CASE_P( +INSTANTIATE_TEST_SUITE_P( SelectAndScatterTest_Instantiation, SelectAndScatterTest, ::testing::Values( SelectAndScatterTestParam{{1, 9, 3, 7, 5, 6}, diff --git a/libspu/mpc/cheetah/rlwe/lwe_ct.cc b/libspu/mpc/cheetah/rlwe/lwe_ct.cc index 66e3a55f..4abc8c67 100644 --- a/libspu/mpc/cheetah/rlwe/lwe_ct.cc +++ b/libspu/mpc/cheetah/rlwe/lwe_ct.cc @@ -521,14 +521,9 @@ void LWECt::CastAsRLWE(const seal::SEALContext &context, uint64_t multiplier, out->scale() = 1.0; } -void PhantomLWECt::WrapIt(const RLWECt &ct, size_t coeff_index, - bool only_wrap_zero) { +void PhantomLWECt::WrapIt(const RLWECt &ct, size_t coeff_index) { SPU_ENFORCE(not ct.is_ntt_form() && ct.size() == 2 && coeff_index < ct.poly_modulus_degree()); - only_wrap_zero_ = only_wrap_zero; - if (only_wrap_zero) { - SPU_ENFORCE(coeff_index == 0); - } coeff_index_ = coeff_index; pid_ = ct.parms_id(); base_ = &ct; @@ -602,24 +597,12 @@ void PhantomLWECt::CastAsRLWE(const seal::SEALContext &context, fixed_mul, modulus[l]); } - if (coeff_index_ == 0 && only_wrap_zero_) { - auto ct0_ptr = base_->data(0) + l * num_coeff; - uint64_t acc = 0; - for (size_t i = 0; i < num_coeff; ++i) { - acc = add_uint_mod(ct0_ptr[i], acc, modulus[l]); - } - acc = - multiply_uint_mod(acc, ntt_tables[l].inv_degree_modulo(), modulus[l]); - acc = multiply_uint_mod(acc, fixed_mul, modulus[l]); - out->data(0)[l * num_coeff] = acc; - } else { - out->data(0)[l * num_coeff] = multiply_uint_mod( - base_->data(0)[l * num_coeff + coeff_index_], fixed_mul, modulus[l]); - } + out->data(0)[l * num_coeff] = multiply_uint_mod( + base_->data(0)[l * num_coeff + coeff_index_], fixed_mul, modulus[l]); src_ptr += num_coeff; dst_ptr += num_coeff; } } -} // namespace spu::mpc::cheetah \ No newline at end of file +} // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/rlwe/lwe_ct.h b/libspu/mpc/cheetah/rlwe/lwe_ct.h index 15b8da6f..617a543c 100644 --- a/libspu/mpc/cheetah/rlwe/lwe_ct.h +++ b/libspu/mpc/cheetah/rlwe/lwe_ct.h @@ -24,7 +24,8 @@ class LWECt { ~LWECt(); - // RLWE(\sum_{i} a_iX^i), k -> LWE(a_k) + // Gvien RLWE(\sum_{i} a_iX^i), k to create a valid LWE ciphertext that + // decrypts to `a_k` LWECt(const RLWECt &rlwe, size_t coeff_index, const seal::SEALContext &context); @@ -40,14 +41,22 @@ class LWECt { LWECt &SubPlainInplace(const std::vector &plain, const seal::SEALContext &context); + // self += other without doing the modulus reduction + // The `other` LWE is extracted from the RLWE on-the-fly LWECt &AddLazyInplace(const RLWECt &rlwe, size_t coeff_index, const seal::SEALContext &context); + // self -= other without doing the modulus reduction + // The `other` LWE is extracted from the RLWE on-the-fly LWECt &SubLazyInplace(const RLWECt &rlwe, size_t coeff_index, const seal::SEALContext &context); - // LWE(a), multiplier -> RLWE(multiplier * a + \sum_{i>0} a_iX^i) for some - // random \{a_i\} + // Simply cast the LWE as an RLWE such that + // the decryption of the RLWE gives the same value in the 0-th coefficient + // i.e., Dec(LWE) = multiplier * Dec(RLWE)[0] + // + // Ref to Section 3.3 ``Efficient Homomorphic Conversion Between (Ring) LWE + // Ciphertexts`` https://eprint.iacr.org/2020/015.pdf void CastAsRLWE(const seal::SEALContext &context, uint64_t multiplier, RLWECt *out) const; @@ -112,18 +121,22 @@ class LWECt { RLWEPt vec_; }; +// We aim to lazy the moduluo reduction in x = x + y mod p given y \in [0, p). +// Because the accumulator `x` is stored in uint64_t, and thus we can lazy +// 62 - log_2(p) times. size_t MaximumLazy(const seal::SEALContext &context); -// Just wrap an RLWE(m(X)) and view it as an LWE of the k-th coefficient, -// LWE(m_k) +// clang-format off +// Wrap an RLWE(m(X)) and view it as an LWE of the k-th coefficient i.e., LWE(m[k]). +// This wrapper class is used to avoid intensive memory allocation. +// clang-format on class PhantomLWECt { public: PhantomLWECt() = default; ~PhantomLWECt() = default; - void WrapIt(const RLWECt &ct, size_t coeff_index, - bool only_wrap_zero = false); + void WrapIt(const RLWECt &ct, size_t coeff_index); seal::parms_id_type parms_id() const { return pid_; } @@ -135,14 +148,19 @@ class PhantomLWECt { bool IsValid() const; + // Simply cast the LWE as an RLWE such that + // the decryption of the RLWE gives the same value in the 0-th coefficient + // i.e., Dec(LWE) = multiplier * Dec(RLWE)[0] + // + // Ref to Section 3.3 ``Efficient Homomorphic Conversion Between (Ring) LWE + // Ciphertexts`` https://eprint.iacr.org/2020/015.pdf void CastAsRLWE(const seal::SEALContext &context, uint64_t multiplier, RLWECt *out) const; private: - bool only_wrap_zero_ = false; size_t coeff_index_ = 0; seal::parms_id_type pid_ = seal::parms_id_zero; const RLWECt *base_ = nullptr; }; -} // namespace spu::mpc::cheetah \ No newline at end of file +} // namespace spu::mpc::cheetah diff --git a/libspu/psi/core/ecdh_oprf_psi.cc b/libspu/psi/core/ecdh_oprf_psi.cc index d056f827..eb3c4c9c 100644 --- a/libspu/psi/core/ecdh_oprf_psi.cc +++ b/libspu/psi/core/ecdh_oprf_psi.cc @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -274,8 +275,11 @@ void EcdhOprfPsiServer::RecvBlindAndShuffleSendEvaluate() { std::vector batch_evaluated_items = oprf_server_->Evaluate(blinded_items); - evaluated_items.insert(evaluated_items.end(), batch_evaluated_items.begin(), - batch_evaluated_items.end()); + // evaluated_items is scoped and will be destructed soon + evaluated_items.insert( + evaluated_items.end(), + std::make_move_iterator(batch_evaluated_items.begin()), + std::make_move_iterator(batch_evaluated_items.end())); batch_count++; } @@ -348,8 +352,10 @@ EcdhOprfPsiServer::RecvIntersectionMaskedItems( idx * compare_length, compare_length); } - client_masked_items.insert(batch_masked_items.begin(), - batch_masked_items.end()); + // batch_masked_items is scoped and will be destructed soon + client_masked_items.insert( + std::make_move_iterator(batch_masked_items.begin()), + std::make_move_iterator(batch_masked_items.end())); batch_count++; } @@ -401,8 +407,10 @@ EcdhOprfPsiServer::RecvIntersectionMaskedItems( f_compare[i].get(); } - for (const auto& r : batch_result) { - indices.insert(indices.end(), r.begin(), r.end()); + // batch_result is scoped and will be destructed soon + for (auto& r : batch_result) { + indices.insert(indices.end(), std::make_move_iterator(r.begin()), + std::make_move_iterator(r.end())); } batch_count++; diff --git a/libspu/spu.proto b/libspu/spu.proto index 67a8e40c..53f6abaf 100644 --- a/libspu/spu.proto +++ b/libspu/spu.proto @@ -74,12 +74,16 @@ enum PtType { PT_U32 = 6; // uint32_t PT_I64 = 7; // int64_t PT_U64 = 8; // uint64_t - PT_F32 = 9; // float - PT_F64 = 10; // double - PT_I128 = 11; // int128_t - PT_U128 = 12; // uint128_t - PT_BOOL = 13; // bool - PT_F16 = 14; // half + PT_I128 = 9; // int128_t + PT_U128 = 10; // uint128_t + PT_BOOL = 11; // bool + + PT_F16 = 30; // half + PT_F32 = 31; // float + PT_F64 = 32; // double + + PT_CF32 = 50; // complex float + PT_CF64 = 51; // complex double } // A security parameter type. @@ -120,17 +124,18 @@ enum ProtocolKind { message ValueMetaProto { // The data type. DataType data_type = 1; + bool is_complex = 2; // The data visibility. - Visibility visibility = 2; + Visibility visibility = 3; // The shape of the value. - ShapeProto shape = 3; + ShapeProto shape = 4; // The storage type, defined by the underline evaluation engine. // i.e. `aby3.AShr` means an aby3 arithmetic share in FM64. // usually, the application does not care about this attribute. - string storage_type = 4; + string storage_type = 5; } // The spu Value proto, used for spu value serialization. diff --git a/sml/decomposition/emulations/pca_emul.py b/sml/decomposition/emulations/pca_emul.py index bdc088a7..33b20def 100644 --- a/sml/decomposition/emulations/pca_emul.py +++ b/sml/decomposition/emulations/pca_emul.py @@ -28,28 +28,21 @@ def emul_powerPCA(mode: emulation.Mode.MULTIPROCESS): - def proc(X): + print("start power method emulation.") + + def proc_transform(X): model = PCA( method='power_iteration', n_components=2, + max_power_iter=200, ) model.fit(X) X_transformed = model.transform(X) X_variances = model._variances + X_reconstructed = model.inverse_transform(X_transformed) - return X_transformed, X_variances - - def proc_reconstruct(X): - model = PCA( - method='power_iteration', - n_components=2, - ) - - model.fit(X) - X_reconstructed = model.inverse_transform(model.transform(X)) - - return X_reconstructed + return X_transformed, X_variances, X_reconstructed try: # bandwidth and latency only work for docker mode @@ -57,11 +50,12 @@ def proc_reconstruct(X): emulation.CLUSTER_ABY3_3PC, mode, bandwidth=300, latency=20 ) emulator.up() + # Create a simple dataset - X = random.normal(random.PRNGKey(0), (15, 100)) - result = emulator.run(proc)(X) - print("X_transformed_jax: ", result[0]) - print("X_transformed_jax: ", result[1]) + X = random.normal(random.PRNGKey(0), (10, 20)) + X_spu = emulator.seal(X) + result = emulator.run(proc_transform)(X_spu) + # The transformed data should have 2 dimensions assert result[0].shape[1] == 2 # The mean of the transformed data should be approximately 0 @@ -70,24 +64,20 @@ def proc_reconstruct(X): # Compare with sklearn model = SklearnPCA(n_components=2) model.fit(X) - X_transformed = model.transform(X) + X_transformed_sklearn = model.transform(X) X_variances = model.explained_variance_ - print("X_transformed_sklearn: ", X_transformed) - print("X_variances_sklearn: ", X_variances) - - result = emulator.run(proc_reconstruct)(X) - - print("X_reconstructed_jax: ", result) + # Compare the transform results(omit sign) + np.testing.assert_allclose( + np.abs(X_transformed_sklearn), np.abs(result[0]), rtol=0.1, atol=0.1 + ) - # Compare with sklearn - model = SklearnPCA(n_components=2) - model.fit(X) - X_reconstructed = model.inverse_transform(model.transform(X)) + # Compare the variance results + np.testing.assert_allclose(X_variances, result[1], rtol=0.1, atol=0.1) - print("X_reconstructed_sklearn: ", X_reconstructed) + X_reconstructed = model.inverse_transform(X_transformed_sklearn) - assert np.allclose(X_reconstructed, result, atol=1e-3) + np.testing.assert_allclose(X_reconstructed, result[2], atol=1e-3) finally: emulator.down() diff --git a/sml/decomposition/emulations/rsvd_pca_emul.py b/sml/decomposition/emulations/rsvd_pca_emul.py index c38c16d5..3a83ff93 100644 --- a/sml/decomposition/emulations/rsvd_pca_emul.py +++ b/sml/decomposition/emulations/rsvd_pca_emul.py @@ -54,11 +54,11 @@ def proc(X, random_matrix, n_components, n_oversamples, max_power_iter, scale): emulator.up() # Create a simple dataset - X = random.normal(random.PRNGKey(0), (1000, 20)) + X = random.normal(random.PRNGKey(0), (50, 20)) X_spu = emulator.seal(X) - n_components = 5 + n_components = 1 n_oversamples = 10 - max_power_iter = 300 + max_power_iter = 100 scale = (10000000, 10000) # Create random_matrix @@ -71,9 +71,6 @@ def proc(X, random_matrix, n_components, n_oversamples, max_power_iter, scale): result = emulator.run(proc, static_argnums=(2, 3, 4, 5))( X_spu, random_matrix_spu, n_components, n_oversamples, max_power_iter, scale ) - print("X_transformed_spu: ", result[0][:5, :]) - print("X_variance_spu: ", result[1]) - print("X_reconstructed_spu:", result[2][:5, :]) # The transformed data should have 2 dimensions assert result[0].shape[1] == n_components @@ -89,13 +86,17 @@ def proc(X, random_matrix, n_components, n_oversamples, max_power_iter, scale): random_state=0, ) model.fit(X) - X_transformed = model.transform(X) + X_transformed_sklearn = model.transform(X) X_variances = model.explained_variance_ - X_reconstructed = model.inverse_transform(X_transformed) + X_reconstructed = model.inverse_transform(X_transformed_sklearn) + + # Compare the transform results(omit sign) + np.testing.assert_allclose( + np.abs(X_transformed_sklearn), np.abs(result[0]), rtol=1, atol=0.1 + ) - print("X_transformed_sklearn: ", X_transformed[:5, :]) - print("X_variances_sklearn: ", X_variances) - print("X_reconstructed_sklearn: ", X_reconstructed[:5, :]) + # Compare the variance results + np.testing.assert_allclose(X_variances, result[1], rtol=1, atol=0.1) assert np.allclose(X_reconstructed, result[2], atol=1e-1) diff --git a/sml/decomposition/tests/pca_test.py b/sml/decomposition/tests/pca_test.py index 2e3f398a..717bf008 100644 --- a/sml/decomposition/tests/pca_test.py +++ b/sml/decomposition/tests/pca_test.py @@ -30,29 +30,44 @@ class UnitTests(unittest.TestCase): - def test_power(self): - sim = spsim.Simulator.simple( + @classmethod + def setUpClass(cls): + print(" ========= start test of pca package ========= \n") + + # 1. init sim + cls.sim64 = spsim.Simulator.simple( 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 ) + config128 = spu_pb2.RuntimeConfig( + protocol=spu_pb2.ProtocolKind.ABY3, + field=spu_pb2.FieldType.FM128, + fxp_fraction_bits=30, + ) + cls.sim128 = spsim.Simulator(3, config128) + + def test_power(self): + print("start test power method.") # Test fit_transform def proc_transform(X): model = PCA( method='power_iteration', n_components=2, + max_power_iter=200, ) model.fit(X) X_transformed = model.transform(X) X_variances = model._variances + X_reconstructed = model.inverse_transform(X_transformed) - return X_transformed, X_variances + return X_transformed, X_variances, X_reconstructed # Create a simple dataset - X = random.normal(random.PRNGKey(0), (15, 100)) + X = random.normal(random.PRNGKey(0), (10, 20)) # Run the simulation - result = spsim.sim_jax(sim, proc_transform)(X) + result = spsim.sim_jax(self.sim64, proc_transform)(X) # The transformed data should have 2 dimensions self.assertEqual(result[0].shape[1], 2) @@ -66,45 +81,26 @@ def proc_transform(X): sklearn_pca = SklearnPCA(n_components=2) X_transformed_sklearn = sklearn_pca.fit_transform(X_np) - # Compare the transform results - print("X_transformed_sklearn: ", X_transformed_sklearn) - print("X_transformed_jax", result[0]) + # Compare the transform results(omit sign) + np.testing.assert_allclose( + np.abs(X_transformed_sklearn), np.abs(result[0]), rtol=0.1, atol=0.1 + ) # Compare the variance results - print( - "X_transformed_sklearn.explained_variance_: ", - sklearn_pca.explained_variance_, + np.testing.assert_allclose( + sklearn_pca.explained_variance_, result[1], rtol=0.1, atol=0.1 ) - print("X_transformed_jax.explained_variance_: ", result[1]) - - # Test inverse_transform - def proc_reconstruct(X): - model = PCA( - method='power_iteration', - n_components=2, - ) - - model.fit(X) - X_reconstructed = model.inverse_transform(model.transform(X)) - - return X_reconstructed - - # Run the simulation - result = spsim.sim_jax(sim, proc_reconstruct)(X) # Run inverse_transform using sklearn X_reconstructed_sklearn = sklearn_pca.inverse_transform(X_transformed_sklearn) # Compare the results - self.assertTrue(np.allclose(X_reconstructed_sklearn, result, atol=0.01)) + np.testing.assert_allclose( + X_reconstructed_sklearn, result[2], atol=0.01, rtol=0.01 + ) def test_rsvd(self): - config = spu_pb2.RuntimeConfig( - protocol=spu_pb2.ProtocolKind.ABY3, - field=spu_pb2.FieldType.FM128, - fxp_fraction_bits=30, - ) - sim = spsim.Simulator(3, config) + print("start test rsvd method.") # Test fit_transform def proc_transform(X, random_matrix): @@ -114,17 +110,22 @@ def proc_transform(X, random_matrix): n_oversamples=n_oversamples, random_matrix=random_matrix, scale=[10000000, 10000], + max_power_iter=100, ) model.fit(X) X_transformed = model.transform(X) X_variances = model._variances + X_reconstructed = model.inverse_transform(X_transformed) - return X_transformed, X_variances + return X_transformed, X_variances, X_reconstructed # Create a simple dataset - X = random.normal(random.PRNGKey(0), (1000, 20)) - n_components = 5 + # Note: + # 1. better for large sample data, like (1000, 20) + # 2. for small data, it may corrupt because the projection will have large error + X = random.normal(random.PRNGKey(0), (50, 20)) + n_components = 1 n_oversamples = 10 # Create random_matrix @@ -134,7 +135,7 @@ def proc_transform(X, random_matrix): ) # Run the simulation - result = spsim.sim_jax(sim, proc_transform)(X, random_matrix) + result = spsim.sim_jax(self.sim128, proc_transform)(X, random_matrix) # The transformed data should have 2 dimensions self.assertEqual(result[0].shape[1], n_components) @@ -154,40 +155,21 @@ def proc_transform(X, random_matrix): sklearn_pca.fit(X_np) X_transformed_sklearn = sklearn_pca.transform(X_np) - # Compare the transform results - print("X_transformed_sklearn: ", X_transformed_sklearn) - print("X_transformed_jax", result[0]) + # Compare the transform results(omit sign) + np.testing.assert_allclose( + np.abs(X_transformed_sklearn), np.abs(result[0]), rtol=1, atol=0.1 + ) # Compare the variance results - print( - "X_transformed_sklearn.explained_variance_: ", - sklearn_pca.explained_variance_, + np.testing.assert_allclose( + sklearn_pca.explained_variance_, result[1], rtol=1, atol=0.1 ) - print("X_transformed_jax.explained_variance_: ", result[1]) - - # Test inverse_transform - def proc_reconstruct(X, random_matrix): - model = PCA( - method='rsvd', - n_components=n_components, - n_oversamples=n_oversamples, - random_matrix=random_matrix, - scale=[10000000, 10000], - ) - - model.fit(X) - X_reconstructed = model.inverse_transform(model.transform(X)) - - return X_reconstructed - - # Run the simulation - result = spsim.sim_jax(sim, proc_reconstruct)(X, random_matrix) # Run inverse_transform using sklearn X_reconstructed_sklearn = sklearn_pca.inverse_transform(X_transformed_sklearn) # Compare the results - self.assertTrue(np.allclose(X_reconstructed_sklearn, result, atol=1e-1)) + np.testing.assert_allclose(X_reconstructed_sklearn, result[2], atol=0.1, rtol=1) def test_rsvd(self): config = spu_pb2.RuntimeConfig( diff --git a/spu/libspu.cc b/spu/libspu.cc index 9042b855..471b3151 100644 --- a/spu/libspu.cc +++ b/spu/libspu.cc @@ -364,21 +364,23 @@ class RuntimeWrapper { void Clear() { env_.clear(); } }; +// numpy type naming: +// https://numpy.org/doc/stable/reference/arrays.scalars.html#sized-aliases #define FOR_PY_FORMATS(FN) \ - FN("b", PT_I8) \ - FN("h", PT_I16) \ - FN("i", PT_I32) \ - FN("l", PT_I64) \ - FN("q", PT_I64) \ - FN("B", PT_U8) \ - FN("H", PT_U16) \ - FN("I", PT_U32) \ - FN("L", PT_U64) \ - FN("Q", PT_U64) \ - FN("e", PT_F16) \ - FN("f", PT_F32) \ - FN("d", PT_F64) \ - FN("?", PT_BOOL) + FN("int8", PT_I8) \ + FN("int16", PT_I16) \ + FN("int32", PT_I32) \ + FN("int64", PT_I64) \ + FN("uint8", PT_U8) \ + FN("uint16", PT_U16) \ + FN("uint32", PT_U32) \ + FN("uint64", PT_U64) \ + FN("float16", PT_F16) \ + FN("float32", PT_F32) \ + FN("float64", PT_F64) \ + FN("bool", PT_BOOL) \ + FN("complex64", PT_CF32) \ + FN("complex128", PT_CF64) // https://docs.python.org/3/library/struct.html#format-characters // https://numpy.org/doc/stable/reference/arrays.scalars.html#built-in-scalar-types @@ -447,7 +449,7 @@ class IoWrapper { size_t GetShareChunkCount(const py::array& arr, int visibility, int owner_rank) { const py::buffer_info& binfo = arr.request(); - const PtType pt_type = PyFormatToPtType(binfo.format); + const PtType pt_type = PyFormatToPtType(py::str(arr.dtype())); spu::PtBufferView view( binfo.ptr, pt_type, Shape(binfo.shape.begin(), binfo.shape.end()), @@ -465,8 +467,9 @@ class IoWrapper { // cost SizeCheck(); + const PtType pt_type = PyFormatToPtType(py::str(arr.dtype())); + const py::buffer_info& binfo = arr.request(); - const PtType pt_type = PyFormatToPtType(binfo.format); spu::PtBufferView view( binfo.ptr, pt_type, Shape(binfo.shape.begin(), binfo.shape.end()), diff --git a/spu/tests/spu_io_test.py b/spu/tests/spu_io_test.py index 580c0f83..9953e4e0 100644 --- a/spu/tests/spu_io_test.py +++ b/spu/tests/spu_io_test.py @@ -220,6 +220,72 @@ def test_io_scalar(self, wsize, prot, field, chunk_size): npt.assert_almost_equal(x, y, decimal=5) + def test_io_single_complex(self, wsize, prot, field, chunk_size): + if prot == spu_pb2.ProtocolKind.ABY3 and wsize != 3: + return + + config = spu_pb2.RuntimeConfig( + protocol=prot, + field=field, + share_max_chunk_size=chunk_size, + ) + io = ppapi.Io(wsize, config) + + # SFXP + x = np.array([1 + 2j, 3 + 4j, 5 + 6j]).astype('complex64') + + xs = io.make_shares(x, spu_pb2.Visibility.VIS_SECRET) + self.assertEqual(len(xs), wsize) + + y = io.reconstruct(xs) + print(xs) + chunk_count = io.get_share_chunk_count(x, spu_pb2.Visibility.VIS_SECRET) + self.assertEqual(len(xs[0].share_chunks), 2 * chunk_count) + + npt.assert_almost_equal(x, y, decimal=5) + + # PFXP + xs = io.make_shares(x, spu_pb2.Visibility.VIS_PUBLIC) + self.assertEqual(len(xs), wsize) + y = io.reconstruct(xs) + chunk_count = io.get_share_chunk_count(x, spu_pb2.Visibility.VIS_PUBLIC) + self.assertEqual(len(xs[0].share_chunks), 2 * chunk_count) + + npt.assert_almost_equal(x, y, decimal=5) + + def test_io_double_complex(self, wsize, prot, field, chunk_size): + if prot == spu_pb2.ProtocolKind.ABY3 and wsize != 3: + return + + config = spu_pb2.RuntimeConfig( + protocol=prot, + field=field, + share_max_chunk_size=chunk_size, + ) + io = ppapi.Io(wsize, config) + + # SFXP + x = np.array([1 + 2j, 3 + 4j, 5 + 6j]) + + xs = io.make_shares(x, spu_pb2.Visibility.VIS_SECRET) + self.assertEqual(len(xs), wsize) + + y = io.reconstruct(xs) + print(xs) + chunk_count = io.get_share_chunk_count(x, spu_pb2.Visibility.VIS_SECRET) + self.assertEqual(len(xs[0].share_chunks), 2 * chunk_count) + + npt.assert_almost_equal(x, y, decimal=5) + + # PFXP + xs = io.make_shares(x, spu_pb2.Visibility.VIS_PUBLIC) + self.assertEqual(len(xs), wsize) + y = io.reconstruct(xs) + chunk_count = io.get_share_chunk_count(x, spu_pb2.Visibility.VIS_PUBLIC) + self.assertEqual(len(xs[0].share_chunks), 2 * chunk_count) + + npt.assert_almost_equal(x, y, decimal=5) + if __name__ == '__main__': unittest.main()