Skip to content

Commit

Permalink
Integrate StableHLO at openxla/stablehlo@b27ef13c
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 618971032
  • Loading branch information
sdasgup3 authored and TensorFlow MLIR Team committed Mar 25, 2024
1 parent a4428bb commit a7c35d9
Show file tree
Hide file tree
Showing 29 changed files with 486 additions and 77 deletions.
2 changes: 2 additions & 0 deletions stablehlo/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,7 @@ cc_library(
":reference_ops",
":reference_process",
":reference_scope",
":reference_tensor",
":reference_value",
":register",
"@llvm-project//llvm:Support",
Expand Down Expand Up @@ -1009,6 +1010,7 @@ cc_library(
"@llvm-project//mlir:AsmParser",
"@llvm-project//mlir:CommonFolders",
"@llvm-project//mlir:ComplexDialect",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FunctionInterfaces",
"@llvm-project//mlir:IR",
Expand Down
3 changes: 2 additions & 1 deletion stablehlo/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,6 @@ cc_library(
hdrs = [
"stablehlo/reference/Api.h",
],
strip_include_prefix = ".",
deps = [
":interpreter_ops",
":reference_configuration",
Expand All @@ -431,6 +430,7 @@ cc_library(
":reference_ops",
":reference_process",
":reference_scope",
":reference_tensor",
":reference_value",
":register",
"@llvm-project//llvm:Support",
Expand Down Expand Up @@ -962,6 +962,7 @@ cc_library(
"@llvm-project//mlir:AsmParser",
"@llvm-project//mlir:CommonFolders",
"@llvm-project//mlir:ComplexDialect",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FunctionInterfaces",
"@llvm-project//mlir:IR",
Expand Down
24 changes: 24 additions & 0 deletions stablehlo/CMakePresets.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
{
"version": 6,
"configurePresets": [
{
"name": "debug",
"displayName": "Debug w/ ccache",
"generator": "Ninja",
"binaryDir": "build/",
"cacheVariables": {
"CMAKE_BUILD_TYPE": "Debug",
"LLVM_ENABLE_ASSERTIONS": "ON",
"LLVM_ENABLE_LLD": "ON",
"STABLEHLO_ENABLE_BINDINGS_PYTHON" : "OFF",
"STABLEHLO_ENABLE_SPLIT_DWARF": "ON",
"CMAKE_CXX_COMPILER_LAUNCHER": "ccache",
"CMAKE_CXX_COMPILER": "clang++",
"CMAKE_C_COMPILER_LAUNCHER": "ccache",
"CMAKE_C_COMPILER": "clang",
"CMAKE_EXPORT_COMPILE_COMMANDS": "ON",
"MLIR_DIR": "${sourceDir}/llvm-build/lib/cmake/mlir"
}
}
]
}
4 changes: 2 additions & 2 deletions stablehlo/WORKSPACE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ workspace(name = "stablehlo")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

LLVM_COMMIT = "e371ada409b225ea990b5ac0d5cafea26a6046e1"
LLVM_COMMIT = "a4ca07f13b560b4f6fa5459eef7159e4f9ee9a6b"

LLVM_SHA256 = "aba42d644bf580345c3628da27ea68c45db3973f1740432eff64372ef26daf62"
LLVM_SHA256 = "fb936389d46b3ce7ee423c0d788e5359da8ce41cfe8996847719920c6f60b044"

http_archive(
name = "llvm-raw",
Expand Down
2 changes: 2 additions & 0 deletions stablehlo/build_tools/github_actions/lint_check_license.sh
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ echo
SKIPPED_SUFFIXES=(
.clang-format
.gitignore
.json
.md
.mlir
.mlir.bc
Expand All @@ -67,6 +68,7 @@ for file in "${CHANGED_FILES[@]}"; do
for suffix in "${SKIPPED_SUFFIXES[@]}"; do
if [[ "$file" = *$suffix ]]; then
skip=1
break
fi
done
if (( skip )); then
Expand Down
53 changes: 53 additions & 0 deletions stablehlo/docs/ide.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# IDE setup

You can find on this page some _opinionated_ IDE setup instructions.
Of course the best IDE is the one that _works for you_.

> If you have an improvement or recommendation to any of the setups, we welcome contributions.
## Visual Studio Code (vscode)

### CMake

Visual Studio Code (vscode) can work pretty well with the CMake build system.

The following extensions are recommended:

* [CMake Tools](https://marketplace.visualstudio.com/items?itemName=ms-vscode.cmake-tools)
* [clangd](https://marketplace.visualstudio.com/items?itemName=llvm-vs-code-extensions.vscode-clangd)

> Note: Installing the `clangd` extension will inform you that you
will need to disable the default _intellisense_ extension.
This is fine, as `clangd` will provide the same functionality.

We include a [CmakePresets.json](../CMakePresets.json) file in the root of the repository.
This file is used by the `CMake Tools` extension to provide a list of _presets_ that
can be used to configure the build. The `CMake Tools` extension will automatically
detect this file and provide the presets.

Additionally, all the configured presets generate the `compile_commands.json` file
in the build directory which will then be picked up by `clangd`.

We recommend additionally setting the following in your `.vscode/settings.json` file:

```json
{
"files.exclude": {
"**/bazel-*": true
}
}
```

## Vim

### LLVM/MLIR settings

Check out the official instructions for [LLVM](https://github.com/llvm/llvm-project/blob/main/llvm/utils/vim/README)
and [MLIR](https://github.com/llvm/llvm-project/blob/main/mlir/utils/vim/README)
settings to enable syntax highlighting and other goodies.

### IDE-like features

Check out the official [documentation for
`clangd`](https://releases.llvm.org/9.0.1/tools/clang/tools/extra/docs/clangd/Installation.html)
to enable features like autocompletion, go to definition, etc.
17 changes: 17 additions & 0 deletions stablehlo/rfcs/20230609-extensibility.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# [RFC] StableHLO Extensibility

Status: Approved<br/>
Initial version: 6/9/2023<br/>
Last updated: 3/8/2024<br/>
Discussion thread: [openxla-discuss](https://groups.google.com/a/openxla.org/g/openxla-discuss/c/Ao5K8fvXoEk/m/OaddRrgyAgAJ).

## Summary

For full details and RFC discussion, see:
[[RFC] StableHLO Extensibility](https://docs.google.com/document/d/1bSyyLA-p1F7KjZgjo563F1WFsPwcZc4eaH5WyQfbsi0/edit#heading=h.kfv34azf3j5k).

In its role as a portability layer between ML frameworks and ML compilers,
StableHLO provides a common vocabulary of well-understood ops along with
compatibility guarantees for them. However, this all works only for a closed set
of ops within the StableHLO dialect. In this document, we propose to offer a
mechanism to create portable abstractions of StableHLO ops.
15 changes: 15 additions & 0 deletions stablehlo/stablehlo/conversions/linalg/tests/miscellaneous.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1072,6 +1072,21 @@ func.func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> {

// -----

// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: func @reverse_dynamic
func.func @reverse_dynamic(%input: tensor<?x3xf32>) -> tensor<?x3xf32> {
%result = "stablehlo.reverse"(%input) {
dimensions = array<i64: 1>, someattr
} : (tensor<?x3xf32>) -> tensor<?x3xf32>
func.return %result : tensor<?x3xf32>
}
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-SAME: {someattr}

// -----

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 2 + d5, d3)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ Value getEmptyTensorFor(OpBuilder &b, Location loc, ShapedType resultType,
// Ask the op for its output shape.
auto shapeSource = cast<InferShapedTypeOpInterface>(op);
SmallVector<Value, 1> reifiedShapes;
(void)shapeSource.reifyReturnTypeShapes(b, operands, reifiedShapes);
assert(succeeded(
shapeSource.reifyReturnTypeShapes(b, operands, reifiedShapes)) &&
"could not reify");
assert(reifiedShapes.size() == 1 && "Expected one reified result");
// Construct sizes for the required dimensions.
for (const auto &en : llvm::enumerate(resultType.getShape())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -354,18 +354,20 @@ struct DataMovementOpConverter : OpConversionPattern<OpTy> {
LogicalResult matchAndRewrite(
OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
if (failed(verifyHloOpBufferOrTensorSemantics(op))) return failure();
if (failed(verifyHloOpBufferOrTensorSemantics(op)))
return rewriter.notifyMatchFailure(
op, "failed to verify hlo buffer or tensor semantics");

ShapedType resultType = getHloOpResultType(op);
resultType =
this->getTypeConverter()->template convertType<ShapedType>(resultType);
if (!resultType) {
if (!resultType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
}

SmallVector<AffineMap, 2> indexingMaps =
Derived::getIndexingMaps(op, &rewriter);
if (indexingMaps.empty()) return failure();
if (indexingMaps.empty())
return rewriter.notifyMatchFailure(op, "could not derive indexing maps");

int64_t nloops = resultType.getRank();
Location loc = op.getLoc();
Expand Down
7 changes: 7 additions & 0 deletions stablehlo/stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1605,6 +1605,13 @@ LogicalResult ReverseOp::inferReturnTypeComponents(
inferredReturnShapes);
}

LogicalResult ReverseOp::reifyReturnTypeShapes(
OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
return ::mlir::hlo::deriveShapeFromOperand(
&builder, getOperation(), operands.front(), &reifiedReturnShapes);
}

//===----------------------------------------------------------------------===//
// RngBitGeneratorOp
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion stablehlo/stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2730,7 +2730,7 @@ def StableHLO_SortOp : StableHLO_Op<"sort",
let hasVerifier = 1;
}

def StableHLO_ReverseOp: StableHLO_Op<"reverse",
def StableHLO_ReverseOp: StableHLO_ShapedInterfaceOp<"reverse",
[Pure, HLO_CompatibleOperandsAndResultType /*reverse_c1*/]> {
let summary = "Reverse operation";
let description = [{
Expand Down
2 changes: 1 addition & 1 deletion stablehlo/stablehlo/dialect/Version.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class Version {
static FailureOr<Version> fromString(llvm::StringRef versionRef);

/// Return a Version representing the current VHLO dialect version.
static Version getCurrentVersion() { return Version(0, 19, 1); }
static Version getCurrentVersion() { return Version(0, 19, 2); }

/// Return a Version representing the minimum supported VHLO dialect version.
static Version getMinimumVersion() { return Version(0, 9, 0); }
Expand Down
1 change: 1 addition & 0 deletions stablehlo/stablehlo/integrations/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ declare_mlir_python_extension(StablehloPythonExtensions.Main
StablehloCAPI
PRIVATE_LINK_LIBS
StablehloPortableApi
StablehloReferenceApi
StablehloSerialization
LLVMSupport
)
Expand Down
38 changes: 37 additions & 1 deletion stablehlo/stablehlo/integrations/python/StablehloModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,18 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>

#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "mlir/CAPI/IR.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "stablehlo/dialect/Serialization.h"
#include "stablehlo/integrations/c/StablehloAttributes.h"
#include "stablehlo/integrations/c/StablehloDialect.h"
#include "stablehlo/integrations/c/StablehloTypes.h"
#include "stablehlo/integrations/python/PortableApi.h"
#include "stablehlo/reference/Api.h"

namespace py = pybind11;

Expand Down Expand Up @@ -483,6 +487,38 @@ PYBIND11_MODULE(_stablehlo, m) {
//
mlir::stablehlo::AddPortableApi(m);

//
// Reference APIs
//
m.def(
"eval_module",
[](MlirModule module,
std::vector<MlirAttribute> &args) -> std::vector<MlirAttribute> {
std::vector<mlir::DenseElementsAttr> inputs;
for (auto arg : args) {
auto attr = unwrap(arg).dyn_cast<mlir::DenseElementsAttr>();
if (!attr) {
PyErr_SetString(PyExc_ValueError,
"input args must be DenseElementsAttr");
return {};
}
inputs.push_back(attr);
}

mlir::stablehlo::InterpreterConfiguration config;
auto results =
mlir::stablehlo::evalModule(unwrap(module), inputs, config);
if (failed(results)) {
PyErr_SetString(PyExc_ValueError, "interpreter failed");
return {};
}

std::vector<MlirAttribute> pyResults;
for (auto res : *results) pyResults.push_back(wrap(res));
return pyResults;
},
py::arg("module"), py::arg("args"));

//
// Serialization APIs.
//
Expand Down Expand Up @@ -515,5 +551,5 @@ PYBIND11_MODULE(_stablehlo, m) {

return {module.release()};
},
py::arg("module"), py::arg("target"));
py::arg("context"), py::arg("artifact"));
}
Loading

0 comments on commit a7c35d9

Please sign in to comment.