Skip to content

Commit

Permalink
Repo sync (#328)
Browse files Browse the repository at this point in the history
  • Loading branch information
anakinxc authored Aug 25, 2023
1 parent a62c779 commit 9cd51e6
Show file tree
Hide file tree
Showing 32 changed files with 609 additions and 349 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
- [Feature] Support half type
- [Feature] Add Psi Progress
- [Feature] Add SineOp/CosineOp support
- [Feature] Add complex support

## 20230705

Expand Down
304 changes: 145 additions & 159 deletions docs/reference/xla_status.md

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion examples/python/ml/ml_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,10 @@ def suite():
suite.addTest(UnitTests('test_ss_xgb'))
suite.addTest(UnitTests('test_stax_mnist_classifier'))
suite.addTest(UnitTests('test_stax_nn'))
suite.addTest(UnitTests('test_save_and_load_model'))
# should put JAX tests above
suite.addTest(UnitTests('test_tf_experiment'))
suite.addTest(UnitTests('test_torch_experiment'))
suite.addTest(UnitTests('test_save_and_load_model'))
return suite


Expand Down
1 change: 1 addition & 0 deletions libspu/compiler/front_end/fe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ void FE::buildFrontEndPipeline(mlir::PassManager *pm, const std::string &args) {
pm->addPass(mlir::mhlo::createExpandHloTuplesPass());

auto &optPM = pm->nest<mlir::func::FuncOp>();
optPM.addPass(mlir::mhlo::createLowerComplexPass());
optPM.addPass(mlir::mhlo::createLegalizeEinsumToDotGeneralPass());
optPM.addPass(mlir::mhlo::createLegalizeGeneralDotPass());
optPM.addPass(mlir::mhlo::createSinkConstantsToControlFlowPass());
Expand Down
6 changes: 5 additions & 1 deletion libspu/compiler/passes/hlo_legalize_to_pphlo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ class HloToPPHloTypeConverter : public TypeConverter {
Type oriElmTy = type.getElementType();
Type newElmTy;
if (oriElmTy.isa<::mlir::FloatType>() ||
oriElmTy.isa<::mlir::IntegerType>()) {
oriElmTy.isa<::mlir::IntegerType>() ||
oriElmTy.isa<::mlir::ComplexType>()) {
newElmTy = ::mlir::pphlo::UnsetType::get(oriElmTy);
} else {
newElmTy = oriElmTy;
Expand Down Expand Up @@ -1531,6 +1532,9 @@ struct HloLegalizeToPPHlo
.insert<FuncOpConverter, ReturnOpConverter, HloCompToPPHloOpConverter,
CustomCallConverter, ReduceOpConverter, ReduceWindowOpConverter,
WhileOpConverter, IfOpConverter, CaseOpConverter,
HloToPPHloOpConverter<stablehlo::RealOp>,
HloToPPHloOpConverter<stablehlo::ImagOp>,
HloToPPHloOpConverter<stablehlo::ComplexOp>,
HloToPPHloOpConverter<stablehlo::AbsOp>,
HloToPPHloOpConverter<stablehlo::AddOp>,
HloToPPHloOpConverter<stablehlo::AndOp>,
Expand Down
3 changes: 3 additions & 0 deletions libspu/compiler/passes/map_stablehlo_to_pphlo_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ MAP_HLO_TO_PPHLO(SubtractOp)
MAP_HLO_TO_PPHLO(TanhOp)
MAP_HLO_TO_PPHLO(TransposeOp)
MAP_HLO_TO_PPHLO(XorOp)
MAP_HLO_TO_PPHLO(RealOp)
MAP_HLO_TO_PPHLO(ImagOp)
MAP_HLO_TO_PPHLO(ComplexOp)

MAP_HLO_TO_PPHLO_DIFF_NAME(BroadcastInDimOp, BroadcastOp)

Expand Down
11 changes: 11 additions & 0 deletions libspu/compiler/tests/hlo_to_pphlo_complex.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// RUN: mlir-pphlo-opt --hlo-legalize-to-pphlo=input_vis_list=VIS_PUBLIC --split-input-file %s | FileCheck %s

func.func @main(%arg0: tensor<3xcomplex<f32>>) -> tensor<3xcomplex<f32>> {
// CHECK: pphlo.real
%0 = stablehlo.real %arg0 : (tensor<3xcomplex<f32>>) -> tensor<3xf32>
// CHECK: pphlo.imag
%1 = stablehlo.imag %arg0 : (tensor<3xcomplex<f32>>) -> tensor<3xf32>
// CHECK: pphlo.complex
%2 = stablehlo.complex %0, %1 : tensor<3xcomplex<f32>>
return %2 : tensor<3xcomplex<f32>>
}
2 changes: 2 additions & 0 deletions libspu/core/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ Type F64 = makePtType(PT_F64);
Type I128 = makePtType(PT_I128);
Type U128 = makePtType(PT_U128);
Type BOOL = makePtType(PT_BOOL);
Type CF32 = makePtType(PT_CF32);
Type CF64 = makePtType(PT_CF64);

bool isFloatTy(const Type& type) {
if (!type.isa<PtTy>()) {
Expand Down
2 changes: 2 additions & 0 deletions libspu/core/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,8 @@ extern Type F32;
extern Type F64;
extern Type I128;
extern Type U128;
extern Type CF32;
extern Type CF64;

class RingTy : public TypeImpl<RingTy, TypeObject, Ring2k> {
using Base = TypeImpl<RingTy, TypeObject, Ring2k>;
Expand Down
1 change: 1 addition & 0 deletions libspu/core/type_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ size_t SizeOf(PtType ptt) {
case PT_INVALID:
return 0;
FOREACH_PT_TYPES(CASE);
FOREACH_COMPLEX_PT_TYPES(CASE);
default:
SPU_THROW("unknown size of {}", ptt);
}
Expand Down
5 changes: 5 additions & 0 deletions libspu/core/type_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include <complex>
#include <numeric>

#include "fmt/format.h"
Expand Down Expand Up @@ -88,6 +89,10 @@ std::ostream& operator<<(std::ostream& os, const DataType& dtype);
FN(PT_U128, uint128_t, U128) \
FN(PT_BOOL, bool, I1)

#define FOREACH_COMPLEX_PT_TYPES(FN) \
FN(PT_CF32, std::complex<float>, CF32) \
FN(PT_CF64, std::complex<double>, CF64)

#define FOREACH_PT_TYPES(FN) \
FOREACH_INT_PT_TYPES(FN) \
FOREACH_FLOAT_PT_TYPES(FN)
Expand Down
50 changes: 40 additions & 10 deletions libspu/core/value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ Visibility getVisibilityFromType(const Type& ty) {
Value::Value(NdArrayRef data, DataType dtype)
: data_(std::move(data)), dtype_(dtype) {}

Value::Value(NdArrayRef real, NdArrayRef imag, DataType dtype)
: data_(std::move(real)), imag_(std::move(imag)), dtype_(dtype) {
SPU_ENFORCE(data_.eltype() == imag_->eltype());
}

Visibility Value::vtype() const {
return getVisibilityFromType(storage_type());
};
Expand All @@ -66,15 +71,15 @@ size_t Value::chunksCount(size_t max_chunk_size) const {

ValueProto Value::toProto(size_t max_chunk_size) const {
SPU_ENFORCE(max_chunk_size > 0);
SPU_ENFORCE(dtype_ != DT_INVALID && vtype() != VIS_INVALID);
SPU_ENFORCE(dtype_ != DT_INVALID && vtype() != VIS_INVALID, "{}", *this);

ValueProto ret;

auto build_chunk = [&](const void* data, size_t size, size_t num_chunks) {
if (size == 0) {
return;
}
ret.chunks.reserve(num_chunks);
ret.chunks.reserve(ret.chunks.size() + num_chunks);
for (size_t i = 0; i < num_chunks; i++) {
size_t chunk_size = std::min(max_chunk_size, size - i * max_chunk_size);

Expand All @@ -92,15 +97,21 @@ ValueProto Value::toProto(size_t max_chunk_size) const {

const size_t num_chunks = chunksCount(max_chunk_size);

if (data_.isCompact()) {
build_chunk(data_.data(), numel() * data_.elsize(), num_chunks);
} else {
// Make a compact clone
auto copy = data_.clone();
SPU_ENFORCE(copy.isCompact(), "Must be a compact copy.");
build_chunk(copy.data(), copy.buf()->size(), num_chunks);
}
auto array_to_chunks = [&](const NdArrayRef& a) {
if (a.isCompact()) {
build_chunk(a.data(), numel() * a.elsize(), num_chunks);
} else {
// Make a compact clone
auto copy = a.clone();
SPU_ENFORCE(copy.isCompact(), "Must be a compact copy.");
build_chunk(copy.data(), copy.buf()->size(), num_chunks);
}
};

array_to_chunks(data_);
if (imag_) {
array_to_chunks(*imag_);
}
ret.meta.CopyFrom(toMetaProto());

return ret;
Expand All @@ -111,6 +122,7 @@ ValueMetaProto Value::toMetaProto() const {

ValueMetaProto proto;
proto.set_data_type(dtype_);
proto.set_is_complex(isComplex());
proto.set_visibility(vtype());
for (const auto& d : shape()) {
proto.mutable_shape()->add_dims(d);
Expand All @@ -121,6 +133,24 @@ ValueMetaProto Value::toMetaProto() const {

Value Value::fromProto(const ValueProto& value) {
const auto& meta = value.meta;
if (meta.is_complex()) {
// real
ValueMetaProto partial = value.meta;
partial.set_is_complex(false);
ValueProto partial_proto;
partial_proto.meta = partial;
auto n = value.chunks.size() / 2;
std::copy_n(value.chunks.begin(), n,
std::back_inserter(partial_proto.chunks));
auto rv = fromProto(partial_proto);

partial_proto.chunks.clear();
std::copy_n(value.chunks.begin() + n, n,
std::back_inserter(partial_proto.chunks));
auto iv = fromProto(partial_proto);
return Value(rv.data(), iv.data(), rv.dtype());
}

const auto eltype = Type::fromString(meta.storage_type());

SPU_ENFORCE(meta.data_type() != DT_INVALID, "invalid data type={}",
Expand Down
6 changes: 6 additions & 0 deletions libspu/core/value.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@ struct ValueProto {

class Value final {
NdArrayRef data_;
std::optional<NdArrayRef> imag_;
DataType dtype_ = DT_INVALID;

public:
Value() = default;
explicit Value(NdArrayRef data, DataType dtype);
explicit Value(NdArrayRef real, NdArrayRef imag, DataType dtype);

/// Forward ndarray methods.
inline int64_t numel() const { return data_.numel(); }
Expand All @@ -58,6 +60,9 @@ class Value final {
const NdArrayRef& data() const { return data_; }
NdArrayRef& data() { return data_; }

const std::optional<NdArrayRef>& imag() const { return imag_; }
std::optional<NdArrayRef>& imag() { return imag_; }

// Get vtype, is readonly and decided by the underline secure compute engine.
Visibility vtype() const;
bool isPublic() const { return vtype() == VIS_PUBLIC; }
Expand All @@ -67,6 +72,7 @@ class Value final {
DataType dtype() const { return dtype_; }
bool isInt() const { return isInteger(dtype()); }
bool isFxp() const { return isFixedPoint(dtype()); }
bool isComplex() const { return imag_.has_value(); }

// Set dtype.
//
Expand Down
64 changes: 64 additions & 0 deletions libspu/device/io.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,31 @@ std::vector<spu::Value> IoClient::makeShares(const PtBufferView &bv,
return result;
}

if (bv.pt_type == PT_CF32 || bv.pt_type == PT_CF64) {
auto s_type = bv.pt_type == PT_CF32 ? PT_F32 : PT_F64;
auto offset = bv.pt_type == PT_CF32 ? sizeof(float) : sizeof(double);
// SPDLOG_INFO("bv.strides = {}", bv.strides);

Strides ds = bv.strides;
for (auto &s : ds) {
s *= 2;
}

PtBufferView real_view(bv.ptr, s_type, bv.shape, ds);
PtBufferView imag_view((std::byte *)bv.ptr + offset, s_type, bv.shape, ds);

auto r_shares = makeShares(real_view, vtype, owner_rank);
auto i_shares = makeShares(imag_view, vtype, owner_rank);

std::vector<spu::Value> result;
result.reserve(world_size_);
for (size_t idx = 0; idx < world_size_; ++idx) {
result.emplace_back(r_shares[idx].data(), i_shares[idx].data(),
r_shares[idx].dtype());
}
return result;
}

// encode to ring.
DataType dtype;
NdArrayRef encoded =
Expand All @@ -78,11 +103,50 @@ std::vector<spu::Value> IoClient::makeShares(const PtBufferView &bv,
return result;
}

template <typename T>
NdArrayRef combineComplex(const NdArrayRef &real, const NdArrayRef &imag,
const Type &complex_type) {
NdArrayRef ret(complex_type, real.shape());
NdArrayView<std::complex<T>> ret_v(ret);
NdArrayView<T> rv(real);
NdArrayView<T> iv(imag);
for (int64_t idx = 0; idx < real.numel(); ++idx) {
ret_v[idx].real(rv[idx]);
ret_v[idx].imag(iv[idx]);
}
return ret;
}

NdArrayRef IoClient::combineShares(absl::Span<spu::Value const> values) {
SPU_ENFORCE(values.size() == world_size_,
"wrong number of shares, got={}, expect={}", values.size(),
world_size_);

if (values.front().isComplex()) {
NdArrayRef real;
NdArrayRef imag;
{
std::vector<spu::Value> reals(values.size());
for (size_t idx = 0; idx < values.size(); ++idx) {
reals[idx] = Value(values[idx].data(), values[idx].dtype());
}
real = combineShares(reals);
}
{
std::vector<spu::Value> imags(values.size());
for (size_t idx = 0; idx < values.size(); ++idx) {
imags[idx] = Value(*values[idx].imag(), values[idx].dtype());
}
imag = combineShares(imags);
}
if (values.front().dtype() == DT_F32) {
return combineComplex<float>(real, imag, CF32);
} else {
SPU_ENFORCE(values.front().dtype() == DT_F64);
return combineComplex<double>(real, imag, CF64);
}
}

const size_t fxp_bits = config_.fxp_fraction_bits();
SPU_ENFORCE(fxp_bits != 0, "fxp should never be zero, please check default");

Expand Down
19 changes: 19 additions & 0 deletions libspu/device/pphlo/pphlo_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1110,6 +1110,25 @@ void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope,
removeValue(sscope, op.getOperand(), opts);
}

void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope,
mlir::pphlo::RealOp &op, const ExecutionOptions &opts) {
auto v = lookupValue(sscope, op.getOperand(), opts);
addValue(sscope, op.getResult(), kernel::hlo::Real(sctx, v), opts);
}

void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope,
mlir::pphlo::ImagOp &op, const ExecutionOptions &opts) {
auto v = lookupValue(sscope, op.getOperand(), opts);
addValue(sscope, op.getResult(), kernel::hlo::Imag(sctx, v), opts);
}

void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope,
mlir::pphlo::ComplexOp &op, const ExecutionOptions &opts) {
auto r = lookupValue(sscope, op.getLhs(), opts);
auto i = lookupValue(sscope, op.getRhs(), opts);
addValue(sscope, op.getResult(), kernel::hlo::Complex(sctx, r, i), opts);
}

#define DEFINE_UNIMPLEMENTED_OP(OpName) \
void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope, \
mlir::pphlo::OpName &, const ExecutionOptions &opts) { \
Expand Down
3 changes: 3 additions & 0 deletions libspu/device/pphlo/pphlo_verifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ class PPHloVerifier {
NO_VERIFY_DEFN(EpsilonOp)
NO_VERIFY_DEFN(CustomCallOp)
NO_VERIFY_DEFN(FreeOp)
NO_VERIFY_DEFN(RealOp)
NO_VERIFY_DEFN(ImagOp)
NO_VERIFY_DEFN(ComplexOp)

#undef NO_VERIFY_DEFN
};
Expand Down
Loading

0 comments on commit 9cd51e6

Please sign in to comment.