Skip to content

Commit

Permalink
[mlir python] Port Python core code to nanobind. (#120473)
Browse files Browse the repository at this point in the history
Relands #118583, with a fix for Python 3.8 compatibility. It was not
possible to set the buffer protocol accessers via slots in Python 3.8.

Why? https://nanobind.readthedocs.io/en/latest/why.html says it better
than I can, but my primary motivation for this change is to improve MLIR
IR construction time from JAX.

For a complicated Google-internal LLM model in JAX, this change improves
the MLIR
lowering time by around 5s (out of around 30s), which is a significant
speedup for simply switching binding frameworks.

To a large extent, this is a mechanical change, for instance changing
`pybind11::` to `nanobind::`.

Notes:
* this PR needs Nanobind 2.4.0, because it needs a bug fix
(wjakob/nanobind#806) that landed in that
release.
* this PR does not port the in-tree dialect extension modules. They can
be ported in a future PR.
* I removed the py::sibling() annotations from def_static and def_class
in `PybindAdapters.h`. These ask pybind11 to try to form an overload
with an existing method, but it's not possible to form mixed
pybind11/nanobind overloads this ways and the parent class is now
defined in nanobind. Better solutions may be possible here.
* nanobind does not contain an exact equivalent of pybind11's buffer
protocol support. It was not hard to add a nanobind implementation of a
similar API.
* nanobind is pickier about casting to std::vector<bool>, expecting that
the input is a sequence of bool types, not truthy values. In a couple of
places I added code to support truthy values during casting.
* nanobind distinguishes bytes (`nb::bytes`) from strings (e.g.,
`std::string`). This required nb::bytes overloads in a few places.
  • Loading branch information
hawkinsp authored Dec 19, 2024
1 parent 89b34ec commit b56d1ec
Show file tree
Hide file tree
Showing 23 changed files with 1,898 additions and 1,583 deletions.
2 changes: 1 addition & 1 deletion mlir/cmake/modules/MLIRDetectPythonEnv.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ macro(mlir_configure_python_dev_packages)
"extension = '${PYTHON_MODULE_EXTENSION}")

mlir_detect_nanobind_install()
find_package(nanobind 2.2 CONFIG REQUIRED)
find_package(nanobind 2.4 CONFIG REQUIRED)
message(STATUS "Found nanobind v${nanobind_VERSION}: ${nanobind_INCLUDE_DIR}")
message(STATUS "Python prefix = '${PYTHON_MODULE_PREFIX}', "
"suffix = '${PYTHON_MODULE_SUFFIX}', "
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Bindings/Python/IRTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#ifndef MLIR_BINDINGS_PYTHON_IRTYPES_H
#define MLIR_BINDINGS_PYTHON_IRTYPES_H

#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"

namespace mlir {

Expand Down
26 changes: 13 additions & 13 deletions mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ static nanobind::object mlirApiObjectToCapsule(nanobind::handle apiObject) {
/// Casts object <-> MlirAffineMap.
template <>
struct type_caster<MlirAffineMap> {
NB_TYPE_CASTER(MlirAffineMap, const_name("MlirAffineMap"));
NB_TYPE_CASTER(MlirAffineMap, const_name("MlirAffineMap"))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
nanobind::object capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToAffineMap(capsule.ptr());
Expand All @@ -87,7 +87,7 @@ struct type_caster<MlirAffineMap> {
/// Casts object <-> MlirAttribute.
template <>
struct type_caster<MlirAttribute> {
NB_TYPE_CASTER(MlirAttribute, const_name("MlirAttribute"));
NB_TYPE_CASTER(MlirAttribute, const_name("MlirAttribute"))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
nanobind::object capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToAttribute(capsule.ptr());
Expand All @@ -108,7 +108,7 @@ struct type_caster<MlirAttribute> {
/// Casts object -> MlirBlock.
template <>
struct type_caster<MlirBlock> {
NB_TYPE_CASTER(MlirBlock, const_name("MlirBlock"));
NB_TYPE_CASTER(MlirBlock, const_name("MlirBlock"))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
nanobind::object capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToBlock(capsule.ptr());
Expand All @@ -119,7 +119,7 @@ struct type_caster<MlirBlock> {
/// Casts object -> MlirContext.
template <>
struct type_caster<MlirContext> {
NB_TYPE_CASTER(MlirContext, const_name("MlirContext"));
NB_TYPE_CASTER(MlirContext, const_name("MlirContext"))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
if (src.is_none()) {
// Gets the current thread-bound context.
Expand All @@ -139,7 +139,7 @@ struct type_caster<MlirContext> {
/// Casts object <-> MlirDialectRegistry.
template <>
struct type_caster<MlirDialectRegistry> {
NB_TYPE_CASTER(MlirDialectRegistry, const_name("MlirDialectRegistry"));
NB_TYPE_CASTER(MlirDialectRegistry, const_name("MlirDialectRegistry"))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
nanobind::object capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToDialectRegistry(capsule.ptr());
Expand All @@ -159,7 +159,7 @@ struct type_caster<MlirDialectRegistry> {
/// Casts object <-> MlirLocation.
template <>
struct type_caster<MlirLocation> {
NB_TYPE_CASTER(MlirLocation, const_name("MlirLocation"));
NB_TYPE_CASTER(MlirLocation, const_name("MlirLocation"))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
if (src.is_none()) {
// Gets the current thread-bound context.
Expand All @@ -185,7 +185,7 @@ struct type_caster<MlirLocation> {
/// Casts object <-> MlirModule.
template <>
struct type_caster<MlirModule> {
NB_TYPE_CASTER(MlirModule, const_name("MlirModule"));
NB_TYPE_CASTER(MlirModule, const_name("MlirModule"))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
nanobind::object capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToModule(capsule.ptr());
Expand All @@ -206,7 +206,7 @@ struct type_caster<MlirModule> {
template <>
struct type_caster<MlirFrozenRewritePatternSet> {
NB_TYPE_CASTER(MlirFrozenRewritePatternSet,
const_name("MlirFrozenRewritePatternSet"));
const_name("MlirFrozenRewritePatternSet"))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
nanobind::object capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
Expand All @@ -225,7 +225,7 @@ struct type_caster<MlirFrozenRewritePatternSet> {
/// Casts object <-> MlirOperation.
template <>
struct type_caster<MlirOperation> {
NB_TYPE_CASTER(MlirOperation, const_name("MlirOperation"));
NB_TYPE_CASTER(MlirOperation, const_name("MlirOperation"))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
nanobind::object capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToOperation(capsule.ptr());
Expand All @@ -247,7 +247,7 @@ struct type_caster<MlirOperation> {
/// Casts object <-> MlirValue.
template <>
struct type_caster<MlirValue> {
NB_TYPE_CASTER(MlirValue, const_name("MlirValue"));
NB_TYPE_CASTER(MlirValue, const_name("MlirValue"))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
nanobind::object capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToValue(capsule.ptr());
Expand All @@ -270,7 +270,7 @@ struct type_caster<MlirValue> {
/// Casts object -> MlirPassManager.
template <>
struct type_caster<MlirPassManager> {
NB_TYPE_CASTER(MlirPassManager, const_name("MlirPassManager"));
NB_TYPE_CASTER(MlirPassManager, const_name("MlirPassManager"))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
nanobind::object capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToPassManager(capsule.ptr());
Expand All @@ -281,7 +281,7 @@ struct type_caster<MlirPassManager> {
/// Casts object <-> MlirTypeID.
template <>
struct type_caster<MlirTypeID> {
NB_TYPE_CASTER(MlirTypeID, const_name("MlirTypeID"));
NB_TYPE_CASTER(MlirTypeID, const_name("MlirTypeID"))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
nanobind::object capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToTypeID(capsule.ptr());
Expand All @@ -303,7 +303,7 @@ struct type_caster<MlirTypeID> {
/// Casts object <-> MlirType.
template <>
struct type_caster<MlirType> {
NB_TYPE_CASTER(MlirType, const_name("MlirType"));
NB_TYPE_CASTER(MlirType, const_name("MlirType"))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
nanobind::object capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToType(capsule.ptr());
Expand Down
10 changes: 4 additions & 6 deletions mlir/include/mlir/Bindings/Python/PybindAdaptors.h
Original file line number Diff line number Diff line change
Expand Up @@ -374,9 +374,8 @@ class pure_subclass {
static_assert(!std::is_member_function_pointer<Func>::value,
"def_staticmethod(...) called with a non-static member "
"function pointer");
py::cpp_function cf(
std::forward<Func>(f), py::name(name), py::scope(thisClass),
py::sibling(py::getattr(thisClass, name, py::none())), extra...);
py::cpp_function cf(std::forward<Func>(f), py::name(name),
py::scope(thisClass), extra...);
thisClass.attr(cf.name()) = py::staticmethod(cf);
return *this;
}
Expand All @@ -387,9 +386,8 @@ class pure_subclass {
static_assert(!std::is_member_function_pointer<Func>::value,
"def_classmethod(...) called with a non-static member "
"function pointer");
py::cpp_function cf(
std::forward<Func>(f), py::name(name), py::scope(thisClass),
py::sibling(py::getattr(thisClass, name, py::none())), extra...);
py::cpp_function cf(std::forward<Func>(f), py::name(name),
py::scope(thisClass), extra...);
thisClass.attr(cf.name()) =
py::reinterpret_borrow<py::object>(PyClassMethod_New(cf.ptr()));
return *this;
Expand Down
39 changes: 19 additions & 20 deletions mlir/lib/Bindings/Python/Globals.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,17 @@
#ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
#define MLIR_BINDINGS_PYTHON_GLOBALS_H

#include "PybindUtils.h"
#include <optional>
#include <string>
#include <vector>

#include "NanobindUtils.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/Support.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"

#include <optional>
#include <string>
#include <vector>

namespace mlir {
namespace python {

Expand Down Expand Up @@ -57,71 +56,71 @@ class PyGlobals {
/// Raises an exception if the mapping already exists and replace == false.
/// This is intended to be called by implementation code.
void registerAttributeBuilder(const std::string &attributeKind,
pybind11::function pyFunc,
nanobind::callable pyFunc,
bool replace = false);

/// Adds a user-friendly type caster. Raises an exception if the mapping
/// already exists and replace == false. This is intended to be called by
/// implementation code.
void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster,
void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster,
bool replace = false);

/// Adds a user-friendly value caster. Raises an exception if the mapping
/// already exists and replace == false. This is intended to be called by
/// implementation code.
void registerValueCaster(MlirTypeID mlirTypeID,
pybind11::function valueCaster,
nanobind::callable valueCaster,
bool replace = false);

/// Adds a concrete implementation dialect class.
/// Raises an exception if the mapping already exists.
/// This is intended to be called by implementation code.
void registerDialectImpl(const std::string &dialectNamespace,
pybind11::object pyClass);
nanobind::object pyClass);

/// Adds a concrete implementation operation class.
/// Raises an exception if the mapping already exists and replace == false.
/// This is intended to be called by implementation code.
void registerOperationImpl(const std::string &operationName,
pybind11::object pyClass, bool replace = false);
nanobind::object pyClass, bool replace = false);

/// Returns the custom Attribute builder for Attribute kind.
std::optional<pybind11::function>
std::optional<nanobind::callable>
lookupAttributeBuilder(const std::string &attributeKind);

/// Returns the custom type caster for MlirTypeID mlirTypeID.
std::optional<pybind11::function> lookupTypeCaster(MlirTypeID mlirTypeID,
std::optional<nanobind::callable> lookupTypeCaster(MlirTypeID mlirTypeID,
MlirDialect dialect);

/// Returns the custom value caster for MlirTypeID mlirTypeID.
std::optional<pybind11::function> lookupValueCaster(MlirTypeID mlirTypeID,
std::optional<nanobind::callable> lookupValueCaster(MlirTypeID mlirTypeID,
MlirDialect dialect);

/// Looks up a registered dialect class by namespace. Note that this may
/// trigger loading of the defining module and can arbitrarily re-enter.
std::optional<pybind11::object>
std::optional<nanobind::object>
lookupDialectClass(const std::string &dialectNamespace);

/// Looks up a registered operation class (deriving from OpView) by operation
/// name. Note that this may trigger a load of the dialect, which can
/// arbitrarily re-enter.
std::optional<pybind11::object>
std::optional<nanobind::object>
lookupOperationClass(llvm::StringRef operationName);

private:
static PyGlobals *instance;
/// Module name prefixes to search under for dialect implementation modules.
std::vector<std::string> dialectSearchPrefixes;
/// Map of dialect namespace to external dialect class object.
llvm::StringMap<pybind11::object> dialectClassMap;
llvm::StringMap<nanobind::object> dialectClassMap;
/// Map of full operation name to external operation class object.
llvm::StringMap<pybind11::object> operationClassMap;
llvm::StringMap<nanobind::object> operationClassMap;
/// Map of attribute ODS name to custom builder.
llvm::StringMap<pybind11::object> attributeBuilderMap;
llvm::StringMap<nanobind::callable> attributeBuilderMap;
/// Map of MlirTypeID to custom type caster.
llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;
llvm::DenseMap<MlirTypeID, nanobind::callable> typeCasterMap;
/// Map of MlirTypeID to custom value caster.
llvm::DenseMap<MlirTypeID, pybind11::object> valueCasterMap;
llvm::DenseMap<MlirTypeID, nanobind::callable> valueCasterMap;
/// Set of dialect namespaces that we have attempted to import implementation
/// modules for.
llvm::StringSet<> loadedDialectModules;
Expand Down
Loading

0 comments on commit b56d1ec

Please sign in to comment.