Skip to content

Commit

Permalink
Repo sync (#357)
Browse files Browse the repository at this point in the history
# Pull Request

## What problem does this PR solve?

Issue Number: Fixed #

## Possible side effects?

- Performance:

- Backward compatibility:
  • Loading branch information
anakinxc authored Sep 22, 2023
1 parent 56c9414 commit 0e1cc5f
Show file tree
Hide file tree
Showing 29 changed files with 659 additions and 131 deletions.
2 changes: 0 additions & 2 deletions .github/ISSUE_TEMPLATE/bug_report.yml
Original file line number Diff line number Diff line change
Expand Up @@ -110,5 +110,3 @@ body:
description: Please copy and paste any relevant log output. This will be
automatically formatted into code, so no need for backticks.
render: Shell


1 change: 1 addition & 0 deletions libspu/compiler/core/core.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ void Core::buildPipeline(mlir::PassManager *pm) {
}
optPM.addPass(mlir::pphlo::createDecomposeComparisonPass());
optPM.addPass(mlir::pphlo::createDecomposeMinMaxPass());
optPM.addPass(mlir::pphlo::createSortLowering());

if (!options.disable_sqrt_plus_epsilon_rewrite()) {
optPM.addPass(mlir::pphlo::createOptimizeSqrtPlusEps());
Expand Down
1 change: 0 additions & 1 deletion libspu/compiler/front_end/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ spu_cc_library(
"@xla//xla/service:scatter_expander",
"@xla//xla/service:slice_sinker",
"@xla//xla/service:sort_simplifier",
"@xla//xla/service:stable_sort_expander",
"@xla//xla/service:triangular_solve_expander",
"@xla//xla/service:while_loop_constant_sinking",
"@xla//xla/service:while_loop_simplifier",
Expand Down
3 changes: 0 additions & 3 deletions libspu/compiler/front_end/hlo_importer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
#include "xla/service/scatter_expander.h"
#include "xla/service/slice_sinker.h"
#include "xla/service/sort_simplifier.h"
#include "xla/service/stable_sort_expander.h"
#include "xla/service/triangular_solve_expander.h"
#include "xla/service/tuple_simplifier.h"
#include "xla/service/while_loop_constant_sinking.h"
Expand Down Expand Up @@ -105,8 +104,6 @@ void runHloPasses(xla::HloModule *module) {

pipeline.AddPass<Convolution4DExpander>();

pipeline.AddPass<StableSortExpander>();

// After canonicalization, there may be more batch dots that can be
// simplified.
pipeline.AddPass<BatchDotSimplification>();
Expand Down
13 changes: 13 additions & 0 deletions libspu/compiler/passes/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,18 @@ spu_cc_library(
],
)

# Library target for the `sort-lowering` pass implemented in sort_lowering.cc;
# its factory function is declared in passes.h.
spu_cc_library(
name = "sort_lowering",
srcs = ["sort_lowering.cc"],
hdrs = ["passes.h"],
deps = [
":pass_details",
"//libspu/dialect:pphlo_dialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:TransformUtils",
],
)

spu_cc_library(
name = "all_passes",
hdrs = ["register_passes.h"],
Expand All @@ -263,5 +275,6 @@ spu_cc_library(
":optimize_sqrt_plus_eps",
":reduce_truncation",
":rewrite_div_sqrt_patterns",
":sort_lowering",
],
)
2 changes: 2 additions & 0 deletions libspu/compiler/passes/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ createOptimizeDenominatorWithBroadcast();

std::unique_ptr<OperationPass<func::FuncOp>> createInsertDeallocationOp();

std::unique_ptr<OperationPass<func::FuncOp>> createSortLowering();

} // namespace pphlo

} // namespace mlir
6 changes: 6 additions & 0 deletions libspu/compiler/passes/passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,9 @@ def InsertDeallocation: Pass<"insert-deallocation", "func::FuncOp"> {
let constructor = "createInsertDeallocationOp()";
let dependentDialects = ["pphlo::PPHloDialect"];
}

// Function-level pass, registered as `sort-lowering`, that is constructed
// through `createSortLowering()` and depends on the PPHlo dialect.
def SortLowering: Pass<"sort-lowering", "func::FuncOp"> {
let summary = "Lower some simple sort to simple sort op";
let constructor = "createSortLowering()";
let dependentDialects = ["pphlo::PPHloDialect"];
}
97 changes: 97 additions & 0 deletions libspu/compiler/passes/sort_lowering.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright 2023 Ant Group Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "spdlog/spdlog.h"

#include "libspu/compiler/passes/pass_details.h"
#include "libspu/compiler/passes/passes.h"
#include "libspu/dialect/pphlo_ops.h"
#include "libspu/dialect/pphlo_types.h"

namespace mlir::pphlo {

namespace {

/// Rewrites a `pphlo.sort` whose comparator is exactly one `pphlo.less` or
/// `pphlo.greater` applied to the first two comparator arguments into a
/// `pphlo.simple_sort` carrying the matching sort direction.
///
/// As a side effect, the `is_stable` flag is cleared on single-operand sorts,
/// where stability cannot be observed.
struct SortConversion : public OpRewritePattern<SortOp> {
public:
  explicit SortConversion(MLIRContext *context)
      : OpRewritePattern<SortOp>(context) {}

  /// Returns success when the op was replaced by `pphlo.simple_sort`, or when
  /// only the `is_stable` flag was cleared; returns failure when the IR was
  /// left untouched.
  LogicalResult matchAndRewrite(SortOp op,
                                PatternRewriter &rewriter) const override {
    // When there is only one operand, stable or not is irrelevant. The change
    // goes through the rewriter (and is reported as success below) so the
    // driver is always aware of the mutation. The `getIsStable()` guard keeps
    // the pattern from firing forever on the same op.
    bool changed = false;
    if (op->getNumOperands() == 1 && op.getIsStable()) {
      rewriter.updateRootInPlace(op, [&]() { op.setIsStable(false); });
      changed = true;
    }
    // Result used by every bail-out: failure unless the flag was updated.
    const LogicalResult not_lowered = success(changed);

    // Only a single-block comparator with a single non-terminator instruction
    // can be a simple sort.
    auto &comp = op.getComparator();
    if (!comp.hasOneBlock() ||
        !llvm::hasSingleElement(comp.front().without_terminator())) {
      return not_lowered;
    }

    auto &block = comp.front();
    auto &inst = block.front();
    if (!mlir::isa<pphlo::LessOp, pphlo::GreaterOp>(inst)) {
      return not_lowered;
    }

    // The comparator must return the comparison itself, nothing else.
    auto *terminator = block.getTerminator();
    if (terminator->getNumOperands() != 1 ||
        terminator->getOperand(0) != inst.getResult(0)) {
      return not_lowered;
    }

    // Both compared values must be arguments of the comparator block. A value
    // captured from an enclosing region makes dyn_cast return null, which must
    // not be dereferenced.
    auto lhs = inst.getOperand(0).dyn_cast<mlir::BlockArgument>();
    auto rhs = inst.getOperand(1).dyn_cast<mlir::BlockArgument>();
    if (!lhs || !rhs || lhs.getOwner() != &block ||
        rhs.getOwner() != &block) {
      return not_lowered;
    }

    // FIXME: If the comparator is using operands other than the first one,
    // we should just reorder operands instead of bailout
    if (lhs.getArgNumber() != 0 || rhs.getArgNumber() != 1) {
      return not_lowered;
    }

    // `greater` sorts in descending order, `less` in ascending order.
    const SortDirection order = mlir::isa<pphlo::GreaterOp>(inst)
                                    ? SortDirection::DES
                                    : SortDirection::ASC;
    mlir::IntegerAttr direction =
        rewriter.getI32IntegerAttr(static_cast<int32_t>(order));

    rewriter.replaceOpWithNewOp<pphlo::SimpleSortOp>(
        op, op.getResultTypes(), op.getOperands(), op.getDimensionAttr(),
        direction);
    return success();
  }
};

/// Function pass that greedily applies `SortConversion` to the operation it
/// runs on. The result of the greedy driver is deliberately ignored.
struct SortLowering : public SortLoweringBase<SortLowering> {
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();

    // Collect the rewrite patterns owned by this pass.
    RewritePatternSet sort_patterns(ctx);
    sort_patterns.insert<SortConversion>(ctx);

    (void)applyPatternsAndFoldGreedily(getOperation(),
                                       std::move(sort_patterns));
  }
};
} // namespace

/// Factory for the sort-lowering pass; returns a newly allocated
/// `SortLowering` instance owned by the caller.
std::unique_ptr<OperationPass<func::FuncOp>> createSortLowering() {
return std::make_unique<SortLowering>();
}

} // namespace mlir::pphlo
11 changes: 11 additions & 0 deletions libspu/compiler/tests/sort_lowering.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// RUN: mlir-pphlo-opt --sort-lowering --split-input-file %s | FileCheck %s

// A single-operand sort whose comparator is one `pphlo.less` on its two block
// arguments is expected to be rewritten into a simple_sort op.
func.func @main(%arg0: tensor<10x!pphlo.sec<f32>>) -> tensor<10x!pphlo.sec<f32>> {
// CHECK: simple_sort
%0 = "pphlo.sort"(%arg0) ({
^bb0(%arg1: tensor<!pphlo.sec<f32>>, %arg2: tensor<!pphlo.sec<f32>>):
%1 = "pphlo.less"(%arg1, %arg2) : (tensor<!pphlo.sec<f32>>, tensor<!pphlo.sec<f32>>) -> tensor<!pphlo.sec<i1>>
"pphlo.return"(%1) : (tensor<!pphlo.sec<i1>>) -> ()
}) {dimension = 0 : i64, is_stable = false} : (tensor<10x!pphlo.sec<f32>>) -> tensor<10x!pphlo.sec<f32>>
return %0 : tensor<10x!pphlo.sec<f32>>
}
4 changes: 3 additions & 1 deletion libspu/core/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ class KernelEvalContext final {
Type, // type of type
uint128_t, // ring constant
int64_t, //
SignType //
SignType, //
std::vector<Value>, // for sort
absl::Span<Value const> // for sort
>;

SPUContext* sctx_;
Expand Down
22 changes: 20 additions & 2 deletions libspu/core/ndarray_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,29 @@ class NdArrayRef {
NdArrayRef(const NdArrayRef& other) = default;
NdArrayRef(NdArrayRef&& other) = default;
NdArrayRef& operator=(const NdArrayRef& other) = default;

#ifndef NDEBUG
// GCC 11.4 with -O1 is not happy with default assign operator when using
// std::reverse...
NdArrayRef& operator=(NdArrayRef&& other) noexcept {
  // Self-assignment leaves the object untouched.
  if (this == &other) {
    return *this;
  }

  // Move by exchanging every member with `other`, so `other` ends up holding
  // this object's previous state.
  std::swap(buf_, other.buf_);
  std::swap(eltype_, other.eltype_);
  std::swap(shape_, other.shape_);
  std::swap(strides_, other.strides_);
  std::swap(offset_, other.offset_);
  std::swap(use_fast_indexing_, other.use_fast_indexing_);
  std::swap(fast_indexing_stride_, other.fast_indexing_stride_);
  return *this;
}
#else
NdArrayRef& operator=(NdArrayRef&& other) = default;
#endif

// Two refs compare equal when their shape, strides, offset and buffer all
// compare equal; no other member takes part in the comparison.
bool operator==(const NdArrayRef& other) const {
return buf_ == other.buf_ && shape_ == other.shape_ &&
strides_ == other.strides_ && offset_ == other.offset_;
return shape_ == other.shape_ && strides_ == other.strides_ &&
offset_ == other.offset_ && buf_ == other.buf_;
}

bool operator!=(const NdArrayRef& other) const { return !(*this == other); }
Expand Down
2 changes: 1 addition & 1 deletion libspu/core/trace.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ inline std::ostream& operator<<(std::ostream& os,
}

// Streams `indices` as a comma-separated list wrapped in braces and returns
// the stream so calls can be chained.
inline std::ostream& operator<<(std::ostream& os,
const absl::Span<int64_t const>& indices) {
absl::Span<int64_t const> indices) {
os << fmt::format("{{{}}}", fmt::join(indices, ","));
return os;
}
Expand Down
26 changes: 26 additions & 0 deletions libspu/device/pphlo/pphlo_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,32 @@ void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope,
}
}

/// Executes a `pphlo.simple_sort` op.
///
/// Looks up every operand in `sscope`, maps the op's sort-direction attribute
/// to the HAL enum, runs `kernel::hlo::SimpleSort` along the op's dimension and
/// binds each returned value to the corresponding op result.
///
/// Throws (via SPU_THROW) when the direction attribute is neither ASC nor DES,
/// or when the kernel does not return one value per op result.
void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope,
             mlir::pphlo::SimpleSortOp &op, const ExecutionOptions &opts) {
  const auto sort_dim = op.getDimension();

  std::vector<spu::Value> inputs(op->getNumOperands());
  for (size_t idx = 0; idx < inputs.size(); ++idx) {
    inputs[idx] = lookupValue(sscope, op->getOperand(idx), opts);
  }

  // Read the attribute once, then translate it to the kernel-level enum.
  const auto direction_value = op.getSortDirectionAttr().getInt();
  kernel::hal::SortDirection direction;
  if (direction_value == static_cast<int>(mlir::pphlo::SortDirection::ASC)) {
    direction = kernel::hal::SortDirection::Ascending;
  } else if (direction_value ==
             static_cast<int>(mlir::pphlo::SortDirection::DES)) {
    direction = kernel::hal::SortDirection::Descending;
  } else {
    SPU_THROW("Should not reach here");
  }

  auto ret = kernel::hlo::SimpleSort(sctx, inputs, sort_dim, direction);

  // Guard the indexing below: every op result needs a kernel output.
  const size_t num_results = op->getNumResults();
  if (ret.size() != num_results) {
    SPU_THROW("simple_sort returned an unexpected number of results");
  }

  for (size_t idx = 0; idx < num_results; ++idx) {
    addValue(sscope, op->getResult(idx), ret[idx], opts);
  }
}

void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope,
mlir::pphlo::SelectAndScatterOp &op,
const ExecutionOptions &opts) {
Expand Down
1 change: 1 addition & 0 deletions libspu/device/pphlo/pphlo_verifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ class PPHloVerifier {
NO_VERIFY_DEFN(RealOp)
NO_VERIFY_DEFN(ImagOp)
NO_VERIFY_DEFN(ComplexOp)
NO_VERIFY_DEFN(SimpleSortOp)

#undef NO_VERIFY_DEFN
};
Expand Down
16 changes: 16 additions & 0 deletions libspu/dialect/pphlo_base_enums.td
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,20 @@ def PPHLO_VisibilityAttr
let genSpecializedAttr = 0;
}

//===----------------------------------------------------------------------===//
// Sort direction enum definitions.
//===----------------------------------------------------------------------===//
// Enum cases: ASC is encoded as 0, DES as 1.
def PPHLO_SORT_DIRECTION_ASC : I32EnumAttrCase<"ASC", 0>;
def PPHLO_SORT_DIRECTION_DES : I32EnumAttrCase<"DES", 1>;

// 32-bit integer enum attribute `SortDirection`, generated into the
// ::mlir::pphlo namespace without a specialized attribute class.
def PPHLO_SortDirection : I32EnumAttr<"SortDirection",
"Which mode to sort.",
[
PPHLO_SORT_DIRECTION_ASC,
PPHLO_SORT_DIRECTION_DES
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::pphlo";
}

#endif // SPU_DIALECT_PPHLO_BASE_ENUMS
16 changes: 16 additions & 0 deletions libspu/dialect/pphlo_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,22 @@ def PPHLO_SortOp
CArg<"bool", "false">:$is_stable)>];
}

// Sort op without a comparator region: the ordering is selected by the
// `sort_direction` enum attribute instead.
def PPHLO_SimpleSortOp
    : PPHLO_Op<"simple_sort", [RecursiveMemoryEffects, SameOperandsAndResultShape]> {
  let summary = "Sort operator";
  let description = [{
    Sorts the given `operands` at the given `dimension` with the given
    `sort_direction`.
  }];
  let arguments = (ins
    Variadic<PPHLO_Tensor>:$operands,
    DefaultValuedAttr<I64Attr, "-1">:$dimension,
    PPHLO_SortDirection: $sort_direction
  );

  let results = (outs Variadic<PPHLO_Tensor>);
}

def PPHLO_ReverseOp
: PPHLO_Op<"reverse", [Pure, SameOperandsAndResultType]> {
let summary = "Reverse operator";
Expand Down
6 changes: 6 additions & 0 deletions libspu/kernel/hal/prot_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ Value _trunc_s(SPUContext* ctx, const Value& in, size_t bits, SignType sign) {
return mpc::trunc_s(ctx, in, bits, sign);
}

// Sort entry point of the protocol wrapper: traces the call with the number
// of inputs, then forwards `x` to the dynamically dispatched kernel.
// NOTE(review): despite the `_s` suffix the kernel looked up is named
// "sort_a" — see the FIXME below; confirm this is the intended kernel.
std::vector<Value> _sort_s(SPUContext* ctx, absl::Span<Value const> x) {
SPU_TRACE_HAL_DISP(ctx, x.size());
// FIXME(jimi): formalize mpc sort api
return dynDispatch<std::vector<Value>>(ctx, "sort_a", x);
}

MAP_UNARY_OP(p2s)
MAP_UNARY_OP(s2p)
MAP_UNARY_OP(not_p)
Expand Down
2 changes: 2 additions & 0 deletions libspu/kernel/hal/prot_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ Value _make_p(SPUContext* ctx, uint128_t init, const Shape& shape);
Value _rand_p(SPUContext* ctx, const Shape& shape);
Value _rand_s(SPUContext* ctx, const Shape& shape);

std::vector<Value> _sort_s(SPUContext* ctx, absl::Span<Value const> x);

// NOLINTEND(readability-identifier-naming)

} // namespace spu::kernel::hal
Loading

0 comments on commit 0e1cc5f

Please sign in to comment.