Skip to content

Commit

Permalink
[mlir python] Port Python core code to nanobind.
Browse files Browse the repository at this point in the history
Why? https://nanobind.readthedocs.io/en/latest/why.html says it better
than I can, but my primary motivation for this change is to improve MLIR
IR construction time from JAX.

For a complicated Google-internal LLM model in JAX, this change improves the MLIR
lowering time by around 5s (out of around 30s), which is a significant
speedup for simply switching binding frameworks.

To a large extent, this is a mechanical change, for instance changing pybind11::
to nanobind::.

Notes:
* this PR needs wjakob/nanobind#806 to land in
  nanobind first. Without that fix, importing the MLIR modules will
  fail.
* this PR does not port the in-tree dialect extension modules. They can
  be ported in a future PR.
* I removed the py::sibling() annotations from def_static and def_class
  in PybindAdapters.h. These ask pybind11 to try to form an overload
  with an existing method, but it's not possible to form mixed
  pybind11/nanobind overloads this ways and the parent class is now defined in
  nanobind. Better solutions may be possible here.
* nanobind does not contain an exact equivalent of pybind11's buffer
  protocol support. It was not hard to add a nanobind implementation of
  a similar API.
* nanobind is pickier about casting to std::vector<bool>, expecting that
  the input is a sequence of bool types, not truthy values. In a couple
  of places I added code to support truthy values during casting.
* nanobind distinguishes bytes (nb::bytes) from strings (e.g.,
  std::string). This required nb::bytes overloads in a few places.
  • Loading branch information
hawkinsp committed Dec 4, 2024
1 parent 5d8eabc commit 0f813dc
Show file tree
Hide file tree
Showing 20 changed files with 2,102 additions and 1,853 deletions.
12 changes: 6 additions & 6 deletions mlir/include/mlir/Bindings/Python/Diagnostics.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@
#include <cassert>
#include <string>

#include "llvm/ADT/StringRef.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "llvm/ADT/StringRef.h"

namespace mlir {
namespace python {

/// RAII scope intercepting all diagnostics into a string. The message must be
/// checked before this goes out of scope.
class CollectDiagnosticsToStringScope {
public:
public:
explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) {
handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage,
/*deleteUserData=*/nullptr);
Expand All @@ -34,7 +34,7 @@ class CollectDiagnosticsToStringScope {

[[nodiscard]] std::string takeMessage() { return std::move(errorMessage); }

private:
private:
static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
auto printer = +[](MlirStringRef message, void *data) {
*static_cast<std::string *>(data) +=
Expand All @@ -53,7 +53,7 @@ class CollectDiagnosticsToStringScope {
std::string errorMessage = "";
};

} // namespace python
} // namespace mlir
} // namespace python
} // namespace mlir

#endif // MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H
#endif // MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H
2 changes: 1 addition & 1 deletion mlir/include/mlir/Bindings/Python/IRTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#ifndef MLIR_BINDINGS_PYTHON_IRTYPES_H
#define MLIR_BINDINGS_PYTHON_IRTYPES_H

#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"

namespace mlir {

Expand Down
47 changes: 22 additions & 25 deletions mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@

#include <cstdint>

#include "llvm/ADT/Twine.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "llvm/ADT/Twine.h"

// Raw CAPI type casters need to be declared before use, so always include them
// first.
Expand Down Expand Up @@ -233,8 +233,7 @@ struct type_caster<MlirOperation> {
}
static handle from_cpp(MlirOperation v, rv_policy,
cleanup_list *cleanup) noexcept {
if (v.ptr == nullptr)
return nanobind::none();
if (v.ptr == nullptr) return nanobind::none();
nanobind::object capsule =
nanobind::steal<nanobind::object>(mlirPythonOperationToCapsule(v));
return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
Expand All @@ -255,8 +254,7 @@ struct type_caster<MlirValue> {
}
static handle from_cpp(MlirValue v, rv_policy,
cleanup_list *cleanup) noexcept {
if (v.ptr == nullptr)
return nanobind::none();
if (v.ptr == nullptr) return nanobind::none();
nanobind::object capsule =
nanobind::steal<nanobind::object>(mlirPythonValueToCapsule(v));
return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
Expand Down Expand Up @@ -289,8 +287,7 @@ struct type_caster<MlirTypeID> {
}
static handle from_cpp(MlirTypeID v, rv_policy,
cleanup_list *cleanup) noexcept {
if (v.ptr == nullptr)
return nanobind::none();
if (v.ptr == nullptr) return nanobind::none();
nanobind::object capsule =
nanobind::steal<nanobind::object>(mlirPythonTypeIDToCapsule(v));
return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
Expand Down Expand Up @@ -321,8 +318,8 @@ struct type_caster<MlirType> {
}
};

} // namespace detail
} // namespace nanobind
} // namespace detail
} // namespace nanobind

namespace mlir {
namespace python {
Expand All @@ -340,7 +337,7 @@ namespace nanobind_adaptors {
/// (plus a fair amount of extra curricular poking)
/// TODO: If this proves useful, see about including it in nanobind.
class pure_subclass {
public:
public:
pure_subclass(nanobind::handle scope, const char *derivedClassName,
const nanobind::object &superClass) {
nanobind::object pyType =
Expand Down Expand Up @@ -382,7 +379,7 @@ class pure_subclass {
"function pointer");
nanobind::object cf = nanobind::cpp_function(
std::forward<Func>(f),
nanobind::name(name), // nanobind::scope(thisClass),
nanobind::name(name), // nanobind::scope(thisClass),
extra...);
thisClass.attr(name) = cf;
return *this;
Expand All @@ -396,7 +393,7 @@ class pure_subclass {
"function pointer");
nanobind::object cf = nanobind::cpp_function(
std::forward<Func>(f),
nanobind::name(name), // nanobind::scope(thisClass),
nanobind::name(name), // nanobind::scope(thisClass),
extra...);
thisClass.attr(name) =
nanobind::borrow<nanobind::object>(PyClassMethod_New(cf.ptr()));
Expand All @@ -405,15 +402,15 @@ class pure_subclass {

nanobind::object get_class() const { return thisClass; }

protected:
protected:
nanobind::object superClass;
nanobind::object thisClass;
};

/// Creates a custom subclass of mlir.ir.Attribute, implementing a casting
/// constructor and type checking methods.
class mlir_attribute_subclass : public pure_subclass {
public:
public:
using IsAFunctionTy = bool (*)(MlirAttribute);
using GetTypeIDFunctionTy = MlirTypeID (*)();

Expand Down Expand Up @@ -445,7 +442,7 @@ class mlir_attribute_subclass : public pure_subclass {
// have no additional members, we can just return the instance thus created
// without amending it.
std::string captureTypeName(
typeClassName); // As string in case if typeClassName is not static.
typeClassName); // As string in case if typeClassName is not static.
nanobind::object newCf = nanobind::cpp_function(
[superCls, isaFunction, captureTypeName](
nanobind::object cls, nanobind::object otherAttribute) {
Expand Down Expand Up @@ -491,7 +488,7 @@ class mlir_attribute_subclass : public pure_subclass {
/// Creates a custom subclass of mlir.ir.Type, implementing a casting
/// constructor and type checking methods.
class mlir_type_subclass : public pure_subclass {
public:
public:
using IsAFunctionTy = bool (*)(MlirType);
using GetTypeIDFunctionTy = MlirTypeID (*)();

Expand Down Expand Up @@ -523,7 +520,7 @@ class mlir_type_subclass : public pure_subclass {
// have no additional members, we can just return the instance thus created
// without amending it.
std::string captureTypeName(
typeClassName); // As string in case if typeClassName is not static.
typeClassName); // As string in case if typeClassName is not static.
nanobind::object newCf = nanobind::cpp_function(
[superCls, isaFunction, captureTypeName](nanobind::object cls,
nanobind::object otherType) {
Expand Down Expand Up @@ -573,7 +570,7 @@ class mlir_type_subclass : public pure_subclass {
/// Creates a custom subclass of mlir.ir.Value, implementing a casting
/// constructor and type checking methods.
class mlir_value_subclass : public pure_subclass {
public:
public:
using IsAFunctionTy = bool (*)(MlirValue);

/// Subclasses by looking up the super-class dynamically.
Expand Down Expand Up @@ -601,7 +598,7 @@ class mlir_value_subclass : public pure_subclass {
// have no additional members, we can just return the instance thus created
// without amending it.
std::string captureValueName(
valueClassName); // As string in case if valueClassName is not static.
valueClassName); // As string in case if valueClassName is not static.
nanobind::object newCf = nanobind::cpp_function(
[superCls, isaFunction, captureValueName](nanobind::object cls,
nanobind::object otherValue) {
Expand Down Expand Up @@ -629,12 +626,12 @@ class mlir_value_subclass : public pure_subclass {
}
};

} // namespace nanobind_adaptors
} // namespace nanobind_adaptors

/// RAII scope intercepting all diagnostics into a string. The message must be
/// checked before this goes out of scope.
class CollectDiagnosticsToStringScope {
public:
public:
explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) {
handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage,
/*deleteUserData=*/nullptr);
Expand All @@ -646,7 +643,7 @@ class CollectDiagnosticsToStringScope {

[[nodiscard]] std::string takeMessage() { return std::move(errorMessage); }

private:
private:
static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
auto printer = +[](MlirStringRef message, void *data) {
*static_cast<std::string *>(data) +=
Expand All @@ -665,7 +662,7 @@ class CollectDiagnosticsToStringScope {
std::string errorMessage = "";
};

} // namespace python
} // namespace mlir
} // namespace python
} // namespace mlir

#endif // MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H
#endif // MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H
10 changes: 4 additions & 6 deletions mlir/include/mlir/Bindings/Python/PybindAdaptors.h
Original file line number Diff line number Diff line change
Expand Up @@ -374,9 +374,8 @@ class pure_subclass {
static_assert(!std::is_member_function_pointer<Func>::value,
"def_staticmethod(...) called with a non-static member "
"function pointer");
py::cpp_function cf(
std::forward<Func>(f), py::name(name), py::scope(thisClass),
py::sibling(py::getattr(thisClass, name, py::none())), extra...);
py::cpp_function cf(std::forward<Func>(f), py::name(name),
py::scope(thisClass), extra...);
thisClass.attr(cf.name()) = py::staticmethod(cf);
return *this;
}
Expand All @@ -387,9 +386,8 @@ class pure_subclass {
static_assert(!std::is_member_function_pointer<Func>::value,
"def_classmethod(...) called with a non-static member "
"function pointer");
py::cpp_function cf(
std::forward<Func>(f), py::name(name), py::scope(thisClass),
py::sibling(py::getattr(thisClass, name, py::none())), extra...);
py::cpp_function cf(std::forward<Func>(f), py::name(name),
py::scope(thisClass), extra...);
thisClass.attr(cf.name()) =
py::reinterpret_borrow<py::object>(PyClassMethod_New(cf.ptr()));
return *this;
Expand Down
51 changes: 25 additions & 26 deletions mlir/lib/Bindings/Python/Globals.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,16 @@
#ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
#define MLIR_BINDINGS_PYTHON_GLOBALS_H

#include "PybindUtils.h"
#include <optional>
#include <string>
#include <vector>

#include "mlir-c/IR.h"
#include "mlir/CAPI/Support.h"
#include "NanobindUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"

#include <optional>
#include <string>
#include <vector>
#include "mlir-c/IR.h"
#include "mlir/CAPI/Support.h"

namespace mlir {
namespace python {
Expand Down Expand Up @@ -57,71 +56,71 @@ class PyGlobals {
/// Raises an exception if the mapping already exists and replace == false.
/// This is intended to be called by implementation code.
void registerAttributeBuilder(const std::string &attributeKind,
pybind11::function pyFunc,
nanobind::callable pyFunc,
bool replace = false);

/// Adds a user-friendly type caster. Raises an exception if the mapping
/// already exists and replace == false. This is intended to be called by
/// implementation code.
void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster,
void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster,
bool replace = false);

/// Adds a user-friendly value caster. Raises an exception if the mapping
/// already exists and replace == false. This is intended to be called by
/// implementation code.
void registerValueCaster(MlirTypeID mlirTypeID,
pybind11::function valueCaster,
nanobind::callable valueCaster,
bool replace = false);

/// Adds a concrete implementation dialect class.
/// Raises an exception if the mapping already exists.
/// This is intended to be called by implementation code.
void registerDialectImpl(const std::string &dialectNamespace,
pybind11::object pyClass);
nanobind::object pyClass);

/// Adds a concrete implementation operation class.
/// Raises an exception if the mapping already exists and replace == false.
/// This is intended to be called by implementation code.
void registerOperationImpl(const std::string &operationName,
pybind11::object pyClass, bool replace = false);
nanobind::object pyClass, bool replace = false);

/// Returns the custom Attribute builder for Attribute kind.
std::optional<pybind11::function>
lookupAttributeBuilder(const std::string &attributeKind);
std::optional<nanobind::callable> lookupAttributeBuilder(
const std::string &attributeKind);

/// Returns the custom type caster for MlirTypeID mlirTypeID.
std::optional<pybind11::function> lookupTypeCaster(MlirTypeID mlirTypeID,
std::optional<nanobind::callable> lookupTypeCaster(MlirTypeID mlirTypeID,
MlirDialect dialect);

/// Returns the custom value caster for MlirTypeID mlirTypeID.
std::optional<pybind11::function> lookupValueCaster(MlirTypeID mlirTypeID,
std::optional<nanobind::callable> lookupValueCaster(MlirTypeID mlirTypeID,
MlirDialect dialect);

/// Looks up a registered dialect class by namespace. Note that this may
/// trigger loading of the defining module and can arbitrarily re-enter.
std::optional<pybind11::object>
lookupDialectClass(const std::string &dialectNamespace);
std::optional<nanobind::object> lookupDialectClass(
const std::string &dialectNamespace);

/// Looks up a registered operation class (deriving from OpView) by operation
/// name. Note that this may trigger a load of the dialect, which can
/// arbitrarily re-enter.
std::optional<pybind11::object>
lookupOperationClass(llvm::StringRef operationName);
std::optional<nanobind::object> lookupOperationClass(
llvm::StringRef operationName);

private:
private:
static PyGlobals *instance;
/// Module name prefixes to search under for dialect implementation modules.
std::vector<std::string> dialectSearchPrefixes;
/// Map of dialect namespace to external dialect class object.
llvm::StringMap<pybind11::object> dialectClassMap;
llvm::StringMap<nanobind::object> dialectClassMap;
/// Map of full operation name to external operation class object.
llvm::StringMap<pybind11::object> operationClassMap;
llvm::StringMap<nanobind::object> operationClassMap;
/// Map of attribute ODS name to custom builder.
llvm::StringMap<pybind11::object> attributeBuilderMap;
llvm::StringMap<nanobind::callable> attributeBuilderMap;
/// Map of MlirTypeID to custom type caster.
llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;
llvm::DenseMap<MlirTypeID, nanobind::callable> typeCasterMap;
/// Map of MlirTypeID to custom value caster.
llvm::DenseMap<MlirTypeID, pybind11::object> valueCasterMap;
llvm::DenseMap<MlirTypeID, nanobind::callable> valueCasterMap;
/// Set of dialect namespaces that we have attempted to import implementation
/// modules for.
llvm::StringSet<> loadedDialectModules;
Expand Down
Loading

0 comments on commit 0f813dc

Please sign in to comment.