From aa9b9fb11870d4c19e79d4edfbe5a1bf776001d7 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 4 Dec 2024 02:44:33 +0000 Subject: [PATCH] mlir python] Port Python core code to nanobind. Why? https://nanobind.readthedocs.io/en/latest/why.html says it better than I can, but my primary motivation for this change is to improve MLIR IR construction time from JAX. For a complicated Google-internal LLM model in JAX, this change improves the MLIR lowering time by around 5s (out of around 30s), which is a significant speedup for simply switching binding frameworks. To a large extent, this is a mechanical change, for instance changing pybind11:: to nanobind::. Notes: * this PR needs https://github.com/wjakob/nanobind/pull/806 to land in nanobind first. Without that fix, importing the MLIR modules will fail. * this PR does not port the in-tree dialect extension modules. They can be ported in a future PR. * I removed the py::sibling() annotations from def_static and def_class in PybindAdapters.h. These ask pybind11 to try to form an overload with an existing method, but it's not possible to form mixed pybind11/nanobind overloads this ways and the parent class is now defined in nanobind. Better solutions may be possible here. * nanobind does not contain an exact equivalent of pybind11's buffer protocol support. It was not hard to add a nanobind implementation of a similar API. * nanobind is pickier about casting to std::vector, expecting that the input is a sequence of bool types, not truthy values. In a couple of places I added code to support truthy values during casting. * nanobind distinguishes bytes (nb::bytes) from strings (e.g., std::string). This required nb::bytes overloads in a few places. --- mlir/cmake/modules/MLIRDetectPythonEnv.cmake | 2 +- mlir/include/mlir/Bindings/Python/IRTypes.h | 2 +- .../mlir/Bindings/Python/PybindAdaptors.h | 10 +- mlir/lib/Bindings/Python/Globals.h | 39 +- mlir/lib/Bindings/Python/IRAffine.cpp | 265 ++-- mlir/lib/Bindings/Python/IRAttributes.cpp | 663 +++++--- mlir/lib/Bindings/Python/IRCore.cpp | 1412 +++++++++-------- mlir/lib/Bindings/Python/IRInterfaces.cpp | 171 +- mlir/lib/Bindings/Python/IRModule.cpp | 57 +- mlir/lib/Bindings/Python/IRModule.h | 332 ++-- mlir/lib/Bindings/Python/IRTypes.cpp | 200 +-- mlir/lib/Bindings/Python/MainModule.cpp | 56 +- .../Python/{PybindUtils.h => NanobindUtils.h} | 84 +- mlir/lib/Bindings/Python/Pass.cpp | 58 +- mlir/lib/Bindings/Python/Pass.h | 4 +- mlir/lib/Bindings/Python/Rewrite.cpp | 43 +- mlir/lib/Bindings/Python/Rewrite.h | 4 +- mlir/python/CMakeLists.txt | 3 +- mlir/python/requirements.txt | 2 +- mlir/test/python/ir/symbol_table.py | 3 +- utils/bazel/WORKSPACE | 6 +- .../llvm-project-overlay/mlir/BUILD.bazel | 15 +- 22 files changed, 1862 insertions(+), 1569 deletions(-) rename mlir/lib/Bindings/Python/{PybindUtils.h => NanobindUtils.h} (85%) diff --git a/mlir/cmake/modules/MLIRDetectPythonEnv.cmake b/mlir/cmake/modules/MLIRDetectPythonEnv.cmake index c62ac7fa615ea6..d6bb65c64b8292 100644 --- a/mlir/cmake/modules/MLIRDetectPythonEnv.cmake +++ b/mlir/cmake/modules/MLIRDetectPythonEnv.cmake @@ -39,7 +39,7 @@ macro(mlir_configure_python_dev_packages) "extension = '${PYTHON_MODULE_EXTENSION}") mlir_detect_nanobind_install() - find_package(nanobind 2.2 CONFIG REQUIRED) + find_package(nanobind 2.4 CONFIG REQUIRED) message(STATUS "Found nanobind v${nanobind_VERSION}: ${nanobind_INCLUDE_DIR}") message(STATUS "Python prefix = '${PYTHON_MODULE_PREFIX}', " "suffix = '${PYTHON_MODULE_SUFFIX}', " diff --git a/mlir/include/mlir/Bindings/Python/IRTypes.h b/mlir/include/mlir/Bindings/Python/IRTypes.h index 9afad4c23b3f35..ba9642cf2c6a2d 100644 --- a/mlir/include/mlir/Bindings/Python/IRTypes.h +++ b/mlir/include/mlir/Bindings/Python/IRTypes.h @@ -9,7 +9,7 @@ #ifndef MLIR_BINDINGS_PYTHON_IRTYPES_H #define MLIR_BINDINGS_PYTHON_IRTYPES_H -#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" namespace mlir { diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index c8233355d1d67b..edc69774be9227 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -374,9 +374,8 @@ class pure_subclass { static_assert(!std::is_member_function_pointer::value, "def_staticmethod(...) called with a non-static member " "function pointer"); - py::cpp_function cf( - std::forward(f), py::name(name), py::scope(thisClass), - py::sibling(py::getattr(thisClass, name, py::none())), extra...); + py::cpp_function cf(std::forward(f), py::name(name), + py::scope(thisClass), extra...); thisClass.attr(cf.name()) = py::staticmethod(cf); return *this; } @@ -387,9 +386,8 @@ class pure_subclass { static_assert(!std::is_member_function_pointer::value, "def_classmethod(...) called with a non-static member " "function pointer"); - py::cpp_function cf( - std::forward(f), py::name(name), py::scope(thisClass), - py::sibling(py::getattr(thisClass, name, py::none())), extra...); + py::cpp_function cf(std::forward(f), py::name(name), + py::scope(thisClass), extra...); thisClass.attr(cf.name()) = py::reinterpret_borrow(PyClassMethod_New(cf.ptr())); return *this; diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index a022067f5c7e57..0ec522d14f74bd 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -9,18 +9,17 @@ #ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H #define MLIR_BINDINGS_PYTHON_GLOBALS_H -#include "PybindUtils.h" +#include +#include +#include +#include "NanobindUtils.h" #include "mlir-c/IR.h" #include "mlir/CAPI/Support.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSet.h" -#include -#include -#include - namespace mlir { namespace python { @@ -57,55 +56,55 @@ class PyGlobals { /// Raises an exception if the mapping already exists and replace == false. /// This is intended to be called by implementation code. void registerAttributeBuilder(const std::string &attributeKind, - pybind11::function pyFunc, + nanobind::callable pyFunc, bool replace = false); /// Adds a user-friendly type caster. Raises an exception if the mapping /// already exists and replace == false. This is intended to be called by /// implementation code. - void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster, + void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster, bool replace = false); /// Adds a user-friendly value caster. Raises an exception if the mapping /// already exists and replace == false. This is intended to be called by /// implementation code. void registerValueCaster(MlirTypeID mlirTypeID, - pybind11::function valueCaster, + nanobind::callable valueCaster, bool replace = false); /// Adds a concrete implementation dialect class. /// Raises an exception if the mapping already exists. /// This is intended to be called by implementation code. void registerDialectImpl(const std::string &dialectNamespace, - pybind11::object pyClass); + nanobind::object pyClass); /// Adds a concrete implementation operation class. /// Raises an exception if the mapping already exists and replace == false. /// This is intended to be called by implementation code. void registerOperationImpl(const std::string &operationName, - pybind11::object pyClass, bool replace = false); + nanobind::object pyClass, bool replace = false); /// Returns the custom Attribute builder for Attribute kind. - std::optional + std::optional lookupAttributeBuilder(const std::string &attributeKind); /// Returns the custom type caster for MlirTypeID mlirTypeID. - std::optional lookupTypeCaster(MlirTypeID mlirTypeID, + std::optional lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect); /// Returns the custom value caster for MlirTypeID mlirTypeID. - std::optional lookupValueCaster(MlirTypeID mlirTypeID, + std::optional lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect); /// Looks up a registered dialect class by namespace. Note that this may /// trigger loading of the defining module and can arbitrarily re-enter. - std::optional + std::optional lookupDialectClass(const std::string &dialectNamespace); /// Looks up a registered operation class (deriving from OpView) by operation /// name. Note that this may trigger a load of the dialect, which can /// arbitrarily re-enter. - std::optional + std::optional lookupOperationClass(llvm::StringRef operationName); private: @@ -113,15 +112,15 @@ class PyGlobals { /// Module name prefixes to search under for dialect implementation modules. std::vector dialectSearchPrefixes; /// Map of dialect namespace to external dialect class object. - llvm::StringMap dialectClassMap; + llvm::StringMap dialectClassMap; /// Map of full operation name to external operation class object. - llvm::StringMap operationClassMap; + llvm::StringMap operationClassMap; /// Map of attribute ODS name to custom builder. - llvm::StringMap attributeBuilderMap; + llvm::StringMap attributeBuilderMap; /// Map of MlirTypeID to custom type caster. - llvm::DenseMap typeCasterMap; + llvm::DenseMap typeCasterMap; /// Map of MlirTypeID to custom value caster. - llvm::DenseMap valueCasterMap; + llvm::DenseMap valueCasterMap; /// Set of dialect namespaces that we have attempted to import implementation /// modules for. llvm::StringSet<> loadedDialectModules; diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index b138e131e851ea..2db690309fab8c 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -6,20 +6,19 @@ // //===----------------------------------------------------------------------===// +#include +#include +#include + #include #include -#include -#include -#include -#include +#include #include #include #include #include "IRModule.h" - -#include "PybindUtils.h" - +#include "NanobindUtils.h" #include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" #include "mlir-c/Bindings/Python/Interop.h" @@ -30,7 +29,7 @@ #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" -namespace py = pybind11; +namespace nb = nanobind; using namespace mlir; using namespace mlir::python; @@ -46,23 +45,23 @@ static const char kDumpDocstring[] = /// Throws errors in case of failure, using "action" to describe what the caller /// was attempting to do. template -static void pyListToVector(const py::list &list, +static void pyListToVector(const nb::list &list, llvm::SmallVectorImpl &result, StringRef action) { - result.reserve(py::len(list)); - for (py::handle item : list) { + result.reserve(nb::len(list)); + for (nb::handle item : list) { try { - result.push_back(item.cast()); - } catch (py::cast_error &err) { + result.push_back(nb::cast(item)); + } catch (nb::cast_error &err) { std::string msg = (llvm::Twine("Invalid expression when ") + action + " (" + err.what() + ")") .str(); - throw py::cast_error(msg); - } catch (py::reference_cast_error &err) { + throw std::runtime_error(msg.c_str()); + } catch (std::runtime_error &err) { std::string msg = (llvm::Twine("Invalid expression (None?) when ") + action + " (" + err.what() + ")") .str(); - throw py::cast_error(msg); + throw std::runtime_error(msg.c_str()); } } } @@ -94,7 +93,7 @@ class PyConcreteAffineExpr : public BaseTy { // IsAFunctionTy isaFunction // const char *pyClassName // and redefine bindDerived. - using ClassTy = py::class_; + using ClassTy = nb::class_; using IsAFunctionTy = bool (*)(MlirAffineExpr); PyConcreteAffineExpr() = default; @@ -105,24 +104,25 @@ class PyConcreteAffineExpr : public BaseTy { static MlirAffineExpr castFrom(PyAffineExpr &orig) { if (!DerivedTy::isaFunction(orig)) { - auto origRepr = py::repr(py::cast(orig)).cast(); - throw py::value_error((Twine("Cannot cast affine expression to ") + + auto origRepr = nb::cast(nb::repr(nb::cast(orig))); + throw nb::value_error((Twine("Cannot cast affine expression to ") + DerivedTy::pyClassName + " (from " + origRepr + ")") - .str()); + .str() + .c_str()); } return orig; } - static void bind(py::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); - cls.def(py::init(), py::arg("expr")); + static void bind(nb::module_ &m) { + auto cls = ClassTy(m, DerivedTy::pyClassName); + cls.def(nb::init(), nb::arg("expr")); cls.def_static( "isinstance", [](PyAffineExpr &otherAffineExpr) -> bool { return DerivedTy::isaFunction(otherAffineExpr); }, - py::arg("other")); + nb::arg("other")); DerivedTy::bindDerived(cls); } @@ -144,9 +144,9 @@ class PyAffineConstantExpr : public PyConcreteAffineExpr { } static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"), - py::arg("context") = py::none()); - c.def_property_readonly("value", [](PyAffineConstantExpr &self) { + c.def_static("get", &PyAffineConstantExpr::get, nb::arg("value"), + nb::arg("context").none() = nb::none()); + c.def_prop_ro("value", [](PyAffineConstantExpr &self) { return mlirAffineConstantExprGetValue(self); }); } @@ -164,9 +164,9 @@ class PyAffineDimExpr : public PyConcreteAffineExpr { } static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineDimExpr::get, py::arg("position"), - py::arg("context") = py::none()); - c.def_property_readonly("position", [](PyAffineDimExpr &self) { + c.def_static("get", &PyAffineDimExpr::get, nb::arg("position"), + nb::arg("context").none() = nb::none()); + c.def_prop_ro("position", [](PyAffineDimExpr &self) { return mlirAffineDimExprGetPosition(self); }); } @@ -184,9 +184,9 @@ class PyAffineSymbolExpr : public PyConcreteAffineExpr { } static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"), - py::arg("context") = py::none()); - c.def_property_readonly("position", [](PyAffineSymbolExpr &self) { + c.def_static("get", &PyAffineSymbolExpr::get, nb::arg("position"), + nb::arg("context").none() = nb::none()); + c.def_prop_ro("position", [](PyAffineSymbolExpr &self) { return mlirAffineSymbolExprGetPosition(self); }); } @@ -209,8 +209,8 @@ class PyAffineBinaryExpr : public PyConcreteAffineExpr { } static void bindDerived(ClassTy &c) { - c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs); - c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs); + c.def_prop_ro("lhs", &PyAffineBinaryExpr::lhs); + c.def_prop_ro("rhs", &PyAffineBinaryExpr::rhs); } }; @@ -365,15 +365,14 @@ bool PyAffineExpr::operator==(const PyAffineExpr &other) const { return mlirAffineExprEqual(affineExpr, other.affineExpr); } -py::object PyAffineExpr::getCapsule() { - return py::reinterpret_steal( - mlirPythonAffineExprToCapsule(*this)); +nb::object PyAffineExpr::getCapsule() { + return nb::steal(mlirPythonAffineExprToCapsule(*this)); } -PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) { +PyAffineExpr PyAffineExpr::createFromCapsule(nb::object capsule) { MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr()); if (mlirAffineExprIsNull(rawAffineExpr)) - throw py::error_already_set(); + throw nb::python_error(); return PyAffineExpr( PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)), rawAffineExpr); @@ -424,14 +423,14 @@ bool PyAffineMap::operator==(const PyAffineMap &other) const { return mlirAffineMapEqual(affineMap, other.affineMap); } -py::object PyAffineMap::getCapsule() { - return py::reinterpret_steal(mlirPythonAffineMapToCapsule(*this)); +nb::object PyAffineMap::getCapsule() { + return nb::steal(mlirPythonAffineMapToCapsule(*this)); } -PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) { +PyAffineMap PyAffineMap::createFromCapsule(nb::object capsule) { MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr()); if (mlirAffineMapIsNull(rawAffineMap)) - throw py::error_already_set(); + throw nb::python_error(); return PyAffineMap( PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)), rawAffineMap); @@ -454,11 +453,10 @@ class PyIntegerSetConstraint { bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); } - static void bind(py::module &m) { - py::class_(m, "IntegerSetConstraint", - py::module_local()) - .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr) - .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq); + static void bind(nb::module_ &m) { + nb::class_(m, "IntegerSetConstraint") + .def_prop_ro("expr", &PyIntegerSetConstraint::getExpr) + .def_prop_ro("is_eq", &PyIntegerSetConstraint::isEq); } private: @@ -501,27 +499,25 @@ bool PyIntegerSet::operator==(const PyIntegerSet &other) const { return mlirIntegerSetEqual(integerSet, other.integerSet); } -py::object PyIntegerSet::getCapsule() { - return py::reinterpret_steal( - mlirPythonIntegerSetToCapsule(*this)); +nb::object PyIntegerSet::getCapsule() { + return nb::steal(mlirPythonIntegerSetToCapsule(*this)); } -PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) { +PyIntegerSet PyIntegerSet::createFromCapsule(nb::object capsule) { MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr()); if (mlirIntegerSetIsNull(rawIntegerSet)) - throw py::error_already_set(); + throw nb::python_error(); return PyIntegerSet( PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)), rawIntegerSet); } -void mlir::python::populateIRAffine(py::module &m) { +void mlir::python::populateIRAffine(nb::module_ &m) { //---------------------------------------------------------------------------- // Mapping of PyAffineExpr and derived classes. //---------------------------------------------------------------------------- - py::class_(m, "AffineExpr", py::module_local()) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyAffineExpr::getCapsule) + nb::class_(m, "AffineExpr") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineExpr::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule) .def("__add__", &PyAffineAddExpr::get) .def("__add__", &PyAffineAddExpr::getRHSConstant) @@ -558,7 +554,7 @@ void mlir::python::populateIRAffine(py::module &m) { .def("__eq__", [](PyAffineExpr &self, PyAffineExpr &other) { return self == other; }) .def("__eq__", - [](PyAffineExpr &self, py::object &other) { return false; }) + [](PyAffineExpr &self, nb::object &other) { return false; }) .def("__str__", [](PyAffineExpr &self) { PyPrintAccumulator printAccum; @@ -579,7 +575,7 @@ void mlir::python::populateIRAffine(py::module &m) { [](PyAffineExpr &self) { return static_cast(llvm::hash_value(self.get().ptr)); }) - .def_property_readonly( + .def_prop_ro( "context", [](PyAffineExpr &self) { return self.getContext().getObject(); }) .def("compose", @@ -632,16 +628,16 @@ void mlir::python::populateIRAffine(py::module &m) { .def_static("get_ceil_div", &PyAffineCeilDivExpr::getRHSConstant, "Gets an affine expression containing the rounded-up result " "of dividing an expression by a constant.") - .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"), - py::arg("context") = py::none(), + .def_static("get_constant", &PyAffineConstantExpr::get, nb::arg("value"), + nb::arg("context").none() = nb::none(), "Gets a constant affine expression with the given value.") .def_static( - "get_dim", &PyAffineDimExpr::get, py::arg("position"), - py::arg("context") = py::none(), + "get_dim", &PyAffineDimExpr::get, nb::arg("position"), + nb::arg("context").none() = nb::none(), "Gets an affine expression of a dimension at the given position.") .def_static( - "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"), - py::arg("context") = py::none(), + "get_symbol", &PyAffineSymbolExpr::get, nb::arg("position"), + nb::arg("context").none() = nb::none(), "Gets an affine expression of a symbol at the given position.") .def( "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); }, @@ -659,13 +655,12 @@ void mlir::python::populateIRAffine(py::module &m) { //---------------------------------------------------------------------------- // Mapping of PyAffineMap. //---------------------------------------------------------------------------- - py::class_(m, "AffineMap", py::module_local()) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyAffineMap::getCapsule) + nb::class_(m, "AffineMap") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineMap::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule) .def("__eq__", [](PyAffineMap &self, PyAffineMap &other) { return self == other; }) - .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; }) + .def("__eq__", [](PyAffineMap &self, nb::object &other) { return false; }) .def("__str__", [](PyAffineMap &self) { PyPrintAccumulator printAccum; @@ -687,7 +682,7 @@ void mlir::python::populateIRAffine(py::module &m) { return static_cast(llvm::hash_value(self.get().ptr)); }) .def_static("compress_unused_symbols", - [](py::list affineMaps, DefaultingPyMlirContext context) { + [](nb::list affineMaps, DefaultingPyMlirContext context) { SmallVector maps; pyListToVector( affineMaps, maps, "attempting to create an AffineMap"); @@ -704,7 +699,7 @@ void mlir::python::populateIRAffine(py::module &m) { res.emplace_back(context->getRef(), m); return res; }) - .def_property_readonly( + .def_prop_ro( "context", [](PyAffineMap &self) { return self.getContext().getObject(); }, "Context that owns the Affine Map") @@ -713,7 +708,7 @@ void mlir::python::populateIRAffine(py::module &m) { kDumpDocstring) .def_static( "get", - [](intptr_t dimCount, intptr_t symbolCount, py::list exprs, + [](intptr_t dimCount, intptr_t symbolCount, nb::list exprs, DefaultingPyMlirContext context) { SmallVector affineExprs; pyListToVector( @@ -723,8 +718,8 @@ void mlir::python::populateIRAffine(py::module &m) { affineExprs.size(), affineExprs.data()); return PyAffineMap(context->getRef(), map); }, - py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"), - py::arg("context") = py::none(), + nb::arg("dim_count"), nb::arg("symbol_count"), nb::arg("exprs"), + nb::arg("context").none() = nb::none(), "Gets a map with the given expressions as results.") .def_static( "get_constant", @@ -733,7 +728,7 @@ void mlir::python::populateIRAffine(py::module &m) { mlirAffineMapConstantGet(context->get(), value); return PyAffineMap(context->getRef(), affineMap); }, - py::arg("value"), py::arg("context") = py::none(), + nb::arg("value"), nb::arg("context").none() = nb::none(), "Gets an affine map with a single constant result") .def_static( "get_empty", @@ -741,7 +736,7 @@ void mlir::python::populateIRAffine(py::module &m) { MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get()); return PyAffineMap(context->getRef(), affineMap); }, - py::arg("context") = py::none(), "Gets an empty affine map.") + nb::arg("context").none() = nb::none(), "Gets an empty affine map.") .def_static( "get_identity", [](intptr_t nDims, DefaultingPyMlirContext context) { @@ -749,7 +744,7 @@ void mlir::python::populateIRAffine(py::module &m) { mlirAffineMapMultiDimIdentityGet(context->get(), nDims); return PyAffineMap(context->getRef(), affineMap); }, - py::arg("n_dims"), py::arg("context") = py::none(), + nb::arg("n_dims"), nb::arg("context").none() = nb::none(), "Gets an identity map with the given number of dimensions.") .def_static( "get_minor_identity", @@ -759,8 +754,8 @@ void mlir::python::populateIRAffine(py::module &m) { mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults); return PyAffineMap(context->getRef(), affineMap); }, - py::arg("n_dims"), py::arg("n_results"), - py::arg("context") = py::none(), + nb::arg("n_dims"), nb::arg("n_results"), + nb::arg("context").none() = nb::none(), "Gets a minor identity map with the given number of dimensions and " "results.") .def_static( @@ -768,13 +763,13 @@ void mlir::python::populateIRAffine(py::module &m) { [](std::vector permutation, DefaultingPyMlirContext context) { if (!isPermutation(permutation)) - throw py::cast_error("Invalid permutation when attempting to " - "create an AffineMap"); + throw std::runtime_error("Invalid permutation when attempting to " + "create an AffineMap"); MlirAffineMap affineMap = mlirAffineMapPermutationGet( context->get(), permutation.size(), permutation.data()); return PyAffineMap(context->getRef(), affineMap); }, - py::arg("permutation"), py::arg("context") = py::none(), + nb::arg("permutation"), nb::arg("context").none() = nb::none(), "Gets an affine map that permutes its inputs.") .def( "get_submap", @@ -782,33 +777,33 @@ void mlir::python::populateIRAffine(py::module &m) { intptr_t numResults = mlirAffineMapGetNumResults(self); for (intptr_t pos : resultPos) { if (pos < 0 || pos >= numResults) - throw py::value_error("result position out of bounds"); + throw nb::value_error("result position out of bounds"); } MlirAffineMap affineMap = mlirAffineMapGetSubMap( self, resultPos.size(), resultPos.data()); return PyAffineMap(self.getContext(), affineMap); }, - py::arg("result_positions")) + nb::arg("result_positions")) .def( "get_major_submap", [](PyAffineMap &self, intptr_t nResults) { if (nResults >= mlirAffineMapGetNumResults(self)) - throw py::value_error("number of results out of bounds"); + throw nb::value_error("number of results out of bounds"); MlirAffineMap affineMap = mlirAffineMapGetMajorSubMap(self, nResults); return PyAffineMap(self.getContext(), affineMap); }, - py::arg("n_results")) + nb::arg("n_results")) .def( "get_minor_submap", [](PyAffineMap &self, intptr_t nResults) { if (nResults >= mlirAffineMapGetNumResults(self)) - throw py::value_error("number of results out of bounds"); + throw nb::value_error("number of results out of bounds"); MlirAffineMap affineMap = mlirAffineMapGetMinorSubMap(self, nResults); return PyAffineMap(self.getContext(), affineMap); }, - py::arg("n_results")) + nb::arg("n_results")) .def( "replace", [](PyAffineMap &self, PyAffineExpr &expression, @@ -818,39 +813,37 @@ void mlir::python::populateIRAffine(py::module &m) { self, expression, replacement, numResultDims, numResultSyms); return PyAffineMap(self.getContext(), affineMap); }, - py::arg("expr"), py::arg("replacement"), py::arg("n_result_dims"), - py::arg("n_result_syms")) - .def_property_readonly( + nb::arg("expr"), nb::arg("replacement"), nb::arg("n_result_dims"), + nb::arg("n_result_syms")) + .def_prop_ro( "is_permutation", [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); }) - .def_property_readonly("is_projected_permutation", - [](PyAffineMap &self) { - return mlirAffineMapIsProjectedPermutation(self); - }) - .def_property_readonly( + .def_prop_ro("is_projected_permutation", + [](PyAffineMap &self) { + return mlirAffineMapIsProjectedPermutation(self); + }) + .def_prop_ro( "n_dims", [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); }) - .def_property_readonly( + .def_prop_ro( "n_inputs", [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); }) - .def_property_readonly( + .def_prop_ro( "n_symbols", [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); }) - .def_property_readonly("results", [](PyAffineMap &self) { - return PyAffineMapExprList(self); - }); + .def_prop_ro("results", + [](PyAffineMap &self) { return PyAffineMapExprList(self); }); PyAffineMapExprList::bind(m); //---------------------------------------------------------------------------- // Mapping of PyIntegerSet. //---------------------------------------------------------------------------- - py::class_(m, "IntegerSet", py::module_local()) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyIntegerSet::getCapsule) + nb::class_(m, "IntegerSet") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyIntegerSet::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule) .def("__eq__", [](PyIntegerSet &self, PyIntegerSet &other) { return self == other; }) - .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; }) + .def("__eq__", [](PyIntegerSet &self, nb::object other) { return false; }) .def("__str__", [](PyIntegerSet &self) { PyPrintAccumulator printAccum; @@ -871,7 +864,7 @@ void mlir::python::populateIRAffine(py::module &m) { [](PyIntegerSet &self) { return static_cast(llvm::hash_value(self.get().ptr)); }) - .def_property_readonly( + .def_prop_ro( "context", [](PyIntegerSet &self) { return self.getContext().getObject(); }) .def( @@ -879,14 +872,14 @@ void mlir::python::populateIRAffine(py::module &m) { kDumpDocstring) .def_static( "get", - [](intptr_t numDims, intptr_t numSymbols, py::list exprs, + [](intptr_t numDims, intptr_t numSymbols, nb::list exprs, std::vector eqFlags, DefaultingPyMlirContext context) { if (exprs.size() != eqFlags.size()) - throw py::value_error( + throw nb::value_error( "Expected the number of constraints to match " "that of equality flags"); - if (exprs.empty()) - throw py::value_error("Expected non-empty list of constraints"); + if (exprs.size() == 0) + throw nb::value_error("Expected non-empty list of constraints"); // Copy over to a SmallVector because std::vector has a // specialization for booleans that packs data and does not @@ -901,8 +894,8 @@ void mlir::python::populateIRAffine(py::module &m) { affineExprs.data(), flags.data()); return PyIntegerSet(context->getRef(), set); }, - py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"), - py::arg("eq_flags"), py::arg("context") = py::none()) + nb::arg("num_dims"), nb::arg("num_symbols"), nb::arg("exprs"), + nb::arg("eq_flags"), nb::arg("context").none() = nb::none()) .def_static( "get_empty", [](intptr_t numDims, intptr_t numSymbols, @@ -911,20 +904,20 @@ void mlir::python::populateIRAffine(py::module &m) { mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols); return PyIntegerSet(context->getRef(), set); }, - py::arg("num_dims"), py::arg("num_symbols"), - py::arg("context") = py::none()) + nb::arg("num_dims"), nb::arg("num_symbols"), + nb::arg("context").none() = nb::none()) .def( "get_replaced", - [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs, + [](PyIntegerSet &self, nb::list dimExprs, nb::list symbolExprs, intptr_t numResultDims, intptr_t numResultSymbols) { if (static_cast(dimExprs.size()) != mlirIntegerSetGetNumDims(self)) - throw py::value_error( + throw nb::value_error( "Expected the number of dimension replacement expressions " "to match that of dimensions"); if (static_cast(symbolExprs.size()) != mlirIntegerSetGetNumSymbols(self)) - throw py::value_error( + throw nb::value_error( "Expected the number of symbol replacement expressions " "to match that of symbols"); @@ -940,30 +933,30 @@ void mlir::python::populateIRAffine(py::module &m) { numResultDims, numResultSymbols); return PyIntegerSet(self.getContext(), set); }, - py::arg("dim_exprs"), py::arg("symbol_exprs"), - py::arg("num_result_dims"), py::arg("num_result_symbols")) - .def_property_readonly("is_canonical_empty", - [](PyIntegerSet &self) { - return mlirIntegerSetIsCanonicalEmpty(self); - }) - .def_property_readonly( + nb::arg("dim_exprs"), nb::arg("symbol_exprs"), + nb::arg("num_result_dims"), nb::arg("num_result_symbols")) + .def_prop_ro("is_canonical_empty", + [](PyIntegerSet &self) { + return mlirIntegerSetIsCanonicalEmpty(self); + }) + .def_prop_ro( "n_dims", [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); }) - .def_property_readonly( + .def_prop_ro( "n_symbols", [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); }) - .def_property_readonly( + .def_prop_ro( "n_inputs", [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); }) - .def_property_readonly("n_equalities", - [](PyIntegerSet &self) { - return mlirIntegerSetGetNumEqualities(self); - }) - .def_property_readonly("n_inequalities", - [](PyIntegerSet &self) { - return mlirIntegerSetGetNumInequalities(self); - }) - .def_property_readonly("constraints", [](PyIntegerSet &self) { + .def_prop_ro("n_equalities", + [](PyIntegerSet &self) { + return mlirIntegerSetGetNumEqualities(self); + }) + .def_prop_ro("n_inequalities", + [](PyIntegerSet &self) { + return mlirIntegerSetGetNumInequalities(self); + }) + .def_prop_ro("constraints", [](PyIntegerSet &self) { return PyIntegerSetConstraintList(self); }); PyIntegerSetConstraint::bind(m); diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index cc9532f4e33b2c..c85c4e286fbb61 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -6,23 +6,29 @@ // //===----------------------------------------------------------------------===// +#include +#include +#include +#include +#include +#include + +#include #include +#include #include #include #include "IRModule.h" - -#include "PybindUtils.h" -#include - -#include "llvm/ADT/ScopeExit.h" -#include "llvm/Support/raw_ostream.h" - +#include "NanobindUtils.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" -#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/Support/raw_ostream.h" -namespace py = pybind11; +namespace nb = nanobind; +using namespace nanobind::literals; using namespace mlir; using namespace mlir::python; @@ -123,10 +129,108 @@ subsequent processing. namespace { +struct nb_buffer_info { + void *ptr = nullptr; + ssize_t itemsize = 0; + ssize_t size = 0; + const char *format = nullptr; + ssize_t ndim = 0; + SmallVector shape; + SmallVector strides; + bool readonly = false; + + nb_buffer_info() = default; + + nb_buffer_info(void *ptr, ssize_t itemsize, const char *format, ssize_t ndim, + SmallVector shape_in, + SmallVector strides_in, bool readonly = false) + : ptr(ptr), itemsize(itemsize), format(format), ndim(ndim), + shape(std::move(shape_in)), strides(std::move(strides_in)), + readonly(readonly) { + size = 1; + for (ssize_t i = 0; i < ndim; ++i) { + size *= shape[i]; + } + } + + explicit nb_buffer_info(Py_buffer *view) + : nb_buffer_info(view->buf, view->itemsize, view->format, view->ndim, + {view->shape, view->shape + view->ndim}, + // TODO(phawkins): check for null strides + {view->strides, view->strides + view->ndim}, + view->readonly != 0) {} +}; + +class nb_buffer : public nb::object { + NB_OBJECT_DEFAULT(nb_buffer, object, "buffer", PyObject_CheckBuffer); + + nb_buffer_info request() const { + int flags = PyBUF_STRIDES | PyBUF_FORMAT; + auto *view = new Py_buffer(); + if (PyObject_GetBuffer(ptr(), view, flags) != 0) { + delete view; + throw nb::python_error(); + } + return nb_buffer_info(view); + } +}; + +template +struct nb_format_descriptor {}; + +template <> +struct nb_format_descriptor { + static const char *format() { return "?"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "b"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "B"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "h"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "H"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "i"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "I"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "l"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "L"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "f"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "d"; } +}; + static MlirStringRef toMlirStringRef(const std::string &s) { return mlirStringRefCreate(s.data(), s.size()); } +static MlirStringRef toMlirStringRef(const nb::bytes &s) { + return mlirStringRefCreate(static_cast(s.data()), s.size()); +} + class PyAffineMapAttribute : public PyConcreteAttribute { public: static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; @@ -142,9 +246,9 @@ class PyAffineMapAttribute : public PyConcreteAttribute { MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); return PyAffineMapAttribute(affineMap.getContext(), attr); }, - py::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); - c.def_property_readonly("value", mlirAffineMapAttrGetValue, - "Returns the value of the AffineMap attribute"); + nb::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); + c.def_prop_ro("value", mlirAffineMapAttrGetValue, + "Returns the value of the AffineMap attribute"); } }; @@ -164,25 +268,24 @@ class PyIntegerSetAttribute MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get()); return PyIntegerSetAttribute(integerSet.getContext(), attr); }, - py::arg("integer_set"), "Gets an attribute wrapping an IntegerSet."); + nb::arg("integer_set"), "Gets an attribute wrapping an IntegerSet."); } }; template -static T pyTryCast(py::handle object) { +static T pyTryCast(nb::handle object) { try { - return object.cast(); - } catch (py::cast_error &err) { - std::string msg = - std::string( - "Invalid attribute when attempting to create an ArrayAttribute (") + - err.what() + ")"; - throw py::cast_error(msg); - } catch (py::reference_cast_error &err) { + return nb::cast(object); + } catch (nb::cast_error &err) { + std::string msg = std::string("Invalid attribute when attempting to " + "create an ArrayAttribute (") + + err.what() + ")"; + throw std::runtime_error(msg.c_str()); + } catch (std::runtime_error &err) { std::string msg = std::string("Invalid attribute (None?) when attempting " "to create an ArrayAttribute (") + err.what() + ")"; - throw py::cast_error(msg); + throw std::runtime_error(msg.c_str()); } } @@ -205,14 +308,13 @@ class PyDenseArrayAttribute : public PyConcreteAttribute { EltTy dunderNext() { // Throw if the index has reached the end. if (nextIndex >= mlirDenseArrayGetNumElements(attr.get())) - throw py::stop_iteration(); + throw nb::stop_iteration(); return DerivedT::getElement(attr.get(), nextIndex++); } /// Bind the iterator class. - static void bind(py::module &m) { - py::class_(m, DerivedT::pyIteratorName, - py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, DerivedT::pyIteratorName) .def("__iter__", &PyDenseArrayIterator::dunderIter) .def("__next__", &PyDenseArrayIterator::dunderNext); } @@ -230,17 +332,35 @@ class PyDenseArrayAttribute : public PyConcreteAttribute { /// Bind the attribute class. static void bindDerived(typename PyConcreteAttribute::ClassTy &c) { // Bind the constructor. - c.def_static( - "get", - [](const std::vector &values, DefaultingPyMlirContext ctx) { - return getAttribute(values, ctx->getRef()); - }, - py::arg("values"), py::arg("context") = py::none(), - "Gets a uniqued dense array attribute"); + if constexpr (std::is_same_v) { + c.def_static( + "get", + [](const nb::sequence &py_values, DefaultingPyMlirContext ctx) { + std::vector values; + for (nb::handle py_value : py_values) { + int is_true = PyObject_IsTrue(py_value.ptr()); + if (is_true < 0) { + throw nb::python_error(); + } + values.push_back(is_true); + } + return getAttribute(values, ctx->getRef()); + }, + nb::arg("values"), nb::arg("context").none() = nb::none(), + "Gets a uniqued dense array attribute"); + } else { + c.def_static( + "get", + [](const std::vector &values, DefaultingPyMlirContext ctx) { + return getAttribute(values, ctx->getRef()); + }, + nb::arg("values"), nb::arg("context").none() = nb::none(), + "Gets a uniqued dense array attribute"); + } // Bind the array methods. c.def("__getitem__", [](DerivedT &arr, intptr_t i) { if (i >= mlirDenseArrayGetNumElements(arr)) - throw py::index_error("DenseArray index out of range"); + throw nb::index_error("DenseArray index out of range"); return arr.getItem(i); }); c.def("__len__", [](const DerivedT &arr) { @@ -248,13 +368,13 @@ class PyDenseArrayAttribute : public PyConcreteAttribute { }); c.def("__iter__", [](const DerivedT &arr) { return PyDenseArrayIterator(arr); }); - c.def("__add__", [](DerivedT &arr, const py::list &extras) { + c.def("__add__", [](DerivedT &arr, const nb::list &extras) { std::vector values; intptr_t numOldElements = mlirDenseArrayGetNumElements(arr); - values.reserve(numOldElements + py::len(extras)); + values.reserve(numOldElements + nb::len(extras)); for (intptr_t i = 0; i < numOldElements; ++i) values.push_back(arr.getItem(i)); - for (py::handle attr : extras) + for (nb::handle attr : extras) values.push_back(pyTryCast(attr)); return getAttribute(values, arr.getContext()); }); @@ -358,13 +478,12 @@ class PyArrayAttribute : public PyConcreteAttribute { MlirAttribute dunderNext() { // TODO: Throw is an inefficient way to stop iteration. if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) - throw py::stop_iteration(); + throw nb::stop_iteration(); return mlirArrayAttrGetElement(attr.get(), nextIndex++); } - static void bind(py::module &m) { - py::class_(m, "ArrayAttributeIterator", - py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, "ArrayAttributeIterator") .def("__iter__", &PyArrayAttributeIterator::dunderIter) .def("__next__", &PyArrayAttributeIterator::dunderNext); } @@ -381,9 +500,9 @@ class PyArrayAttribute : public PyConcreteAttribute { static void bindDerived(ClassTy &c) { c.def_static( "get", - [](py::list attributes, DefaultingPyMlirContext context) { + [](nb::list attributes, DefaultingPyMlirContext context) { SmallVector mlirAttributes; - mlirAttributes.reserve(py::len(attributes)); + mlirAttributes.reserve(nb::len(attributes)); for (auto attribute : attributes) { mlirAttributes.push_back(pyTryCast(attribute)); } @@ -391,12 +510,12 @@ class PyArrayAttribute : public PyConcreteAttribute { context->get(), mlirAttributes.size(), mlirAttributes.data()); return PyArrayAttribute(context->getRef(), attr); }, - py::arg("attributes"), py::arg("context") = py::none(), + nb::arg("attributes"), nb::arg("context").none() = nb::none(), "Gets a uniqued Array attribute"); c.def("__getitem__", [](PyArrayAttribute &arr, intptr_t i) { if (i >= mlirArrayAttrGetNumElements(arr)) - throw py::index_error("ArrayAttribute index out of range"); + throw nb::index_error("ArrayAttribute index out of range"); return arr.getItem(i); }) .def("__len__", @@ -406,13 +525,13 @@ class PyArrayAttribute : public PyConcreteAttribute { .def("__iter__", [](const PyArrayAttribute &arr) { return PyArrayAttributeIterator(arr); }); - c.def("__add__", [](PyArrayAttribute arr, py::list extras) { + c.def("__add__", [](PyArrayAttribute arr, nb::list extras) { std::vector attributes; intptr_t numOldElements = mlirArrayAttrGetNumElements(arr); - attributes.reserve(numOldElements + py::len(extras)); + attributes.reserve(numOldElements + nb::len(extras)); for (intptr_t i = 0; i < numOldElements; ++i) attributes.push_back(arr.getItem(i)); - for (py::handle attr : extras) + for (nb::handle attr : extras) attributes.push_back(pyTryCast(attr)); MlirAttribute arrayAttr = mlirArrayAttrGet( arr.getContext()->get(), attributes.size(), attributes.data()); @@ -440,7 +559,7 @@ class PyFloatAttribute : public PyConcreteAttribute { throw MLIRError("Invalid attribute", errors.take()); return PyFloatAttribute(type.getContext(), attr); }, - py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), + nb::arg("type"), nb::arg("value"), nb::arg("loc").none() = nb::none(), "Gets an uniqued float point attribute associated to a type"); c.def_static( "get_f32", @@ -449,7 +568,7 @@ class PyFloatAttribute : public PyConcreteAttribute { context->get(), mlirF32TypeGet(context->get()), value); return PyFloatAttribute(context->getRef(), attr); }, - py::arg("value"), py::arg("context") = py::none(), + nb::arg("value"), nb::arg("context").none() = nb::none(), "Gets an uniqued float point attribute associated to a f32 type"); c.def_static( "get_f64", @@ -458,10 +577,10 @@ class PyFloatAttribute : public PyConcreteAttribute { context->get(), mlirF64TypeGet(context->get()), value); return PyFloatAttribute(context->getRef(), attr); }, - py::arg("value"), py::arg("context") = py::none(), + nb::arg("value"), nb::arg("context").none() = nb::none(), "Gets an uniqued float point attribute associated to a f64 type"); - c.def_property_readonly("value", mlirFloatAttrGetValueDouble, - "Returns the value of the float attribute"); + c.def_prop_ro("value", mlirFloatAttrGetValueDouble, + "Returns the value of the float attribute"); c.def("__float__", mlirFloatAttrGetValueDouble, "Converts the value of the float attribute to a Python float"); } @@ -481,20 +600,20 @@ class PyIntegerAttribute : public PyConcreteAttribute { MlirAttribute attr = mlirIntegerAttrGet(type, value); return PyIntegerAttribute(type.getContext(), attr); }, - py::arg("type"), py::arg("value"), + nb::arg("type"), nb::arg("value"), "Gets an uniqued integer attribute associated to a type"); - c.def_property_readonly("value", toPyInt, - "Returns the value of the integer attribute"); + c.def_prop_ro("value", toPyInt, + "Returns the value of the integer attribute"); c.def("__int__", toPyInt, "Converts the value of the integer attribute to a Python int"); - c.def_property_readonly_static("static_typeid", - [](py::object & /*class*/) -> MlirTypeID { - return mlirIntegerAttrGetTypeID(); - }); + c.def_prop_ro_static("static_typeid", + [](nb::object & /*class*/) -> MlirTypeID { + return mlirIntegerAttrGetTypeID(); + }); } private: - static py::int_ toPyInt(PyIntegerAttribute &self) { + static int64_t toPyInt(PyIntegerAttribute &self) { MlirType type = mlirAttributeGetType(self); if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type)) return mlirIntegerAttrGetValueInt(self); @@ -518,10 +637,10 @@ class PyBoolAttribute : public PyConcreteAttribute { MlirAttribute attr = mlirBoolAttrGet(context->get(), value); return PyBoolAttribute(context->getRef(), attr); }, - py::arg("value"), py::arg("context") = py::none(), + nb::arg("value"), nb::arg("context").none() = nb::none(), "Gets an uniqued bool attribute"); - c.def_property_readonly("value", mlirBoolAttrGetValue, - "Returns the value of the bool attribute"); + c.def_prop_ro("value", mlirBoolAttrGetValue, + "Returns the value of the bool attribute"); c.def("__bool__", mlirBoolAttrGetValue, "Converts the value of the bool attribute to a Python bool"); } @@ -555,9 +674,9 @@ class PySymbolRefAttribute : public PyConcreteAttribute { DefaultingPyMlirContext context) { return PySymbolRefAttribute::fromList(symbols, context.resolve()); }, - py::arg("symbols"), py::arg("context") = py::none(), + nb::arg("symbols"), nb::arg("context").none() = nb::none(), "Gets a uniqued SymbolRef attribute from a list of symbol names"); - c.def_property_readonly( + c.def_prop_ro( "value", [](PySymbolRefAttribute &self) { std::vector symbols = { @@ -589,13 +708,13 @@ class PyFlatSymbolRefAttribute mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); return PyFlatSymbolRefAttribute(context->getRef(), attr); }, - py::arg("value"), py::arg("context") = py::none(), + nb::arg("value"), nb::arg("context").none() = nb::none(), "Gets a uniqued FlatSymbolRef attribute"); - c.def_property_readonly( + c.def_prop_ro( "value", [](PyFlatSymbolRefAttribute &self) { MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); - return py::str(stringRef.data, stringRef.length); + return nb::str(stringRef.data, stringRef.length); }, "Returns the value of the FlatSymbolRef attribute as a string"); } @@ -612,29 +731,29 @@ class PyOpaqueAttribute : public PyConcreteAttribute { static void bindDerived(ClassTy &c) { c.def_static( "get", - [](std::string dialectNamespace, py::buffer buffer, PyType &type, + [](std::string dialectNamespace, nb_buffer buffer, PyType &type, DefaultingPyMlirContext context) { - const py::buffer_info bufferInfo = buffer.request(); + const nb_buffer_info bufferInfo = buffer.request(); intptr_t bufferSize = bufferInfo.size; MlirAttribute attr = mlirOpaqueAttrGet( context->get(), toMlirStringRef(dialectNamespace), bufferSize, static_cast(bufferInfo.ptr), type); return PyOpaqueAttribute(context->getRef(), attr); }, - py::arg("dialect_namespace"), py::arg("buffer"), py::arg("type"), - py::arg("context") = py::none(), "Gets an Opaque attribute."); - c.def_property_readonly( + nb::arg("dialect_namespace"), nb::arg("buffer"), nb::arg("type"), + nb::arg("context").none() = nb::none(), "Gets an Opaque attribute."); + c.def_prop_ro( "dialect_namespace", [](PyOpaqueAttribute &self) { MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self); - return py::str(stringRef.data, stringRef.length); + return nb::str(stringRef.data, stringRef.length); }, "Returns the dialect namespace for the Opaque attribute as a string"); - c.def_property_readonly( + c.def_prop_ro( "data", [](PyOpaqueAttribute &self) { MlirStringRef stringRef = mlirOpaqueAttrGetData(self); - return py::bytes(stringRef.data, stringRef.length); + return nb::bytes(stringRef.data, stringRef.length); }, "Returns the data for the Opaqued attributes as `bytes`"); } @@ -656,7 +775,16 @@ class PyStringAttribute : public PyConcreteAttribute { mlirStringAttrGet(context->get(), toMlirStringRef(value)); return PyStringAttribute(context->getRef(), attr); }, - py::arg("value"), py::arg("context") = py::none(), + nb::arg("value"), nb::arg("context").none() = nb::none(), + "Gets a uniqued string attribute"); + c.def_static( + "get", + [](nb::bytes value, DefaultingPyMlirContext context) { + MlirAttribute attr = + mlirStringAttrGet(context->get(), toMlirStringRef(value)); + return PyStringAttribute(context->getRef(), attr); + }, + nb::arg("value"), nb::arg("context").none() = nb::none(), "Gets a uniqued string attribute"); c.def_static( "get_typed", @@ -665,20 +793,20 @@ class PyStringAttribute : public PyConcreteAttribute { mlirStringAttrTypedGet(type, toMlirStringRef(value)); return PyStringAttribute(type.getContext(), attr); }, - py::arg("type"), py::arg("value"), + nb::arg("type"), nb::arg("value"), "Gets a uniqued string attribute associated to a type"); - c.def_property_readonly( + c.def_prop_ro( "value", [](PyStringAttribute &self) { MlirStringRef stringRef = mlirStringAttrGetValue(self); - return py::str(stringRef.data, stringRef.length); + return nb::str(stringRef.data, stringRef.length); }, "Returns the value of the string attribute"); - c.def_property_readonly( + c.def_prop_ro( "value_bytes", [](PyStringAttribute &self) { MlirStringRef stringRef = mlirStringAttrGetValue(self); - return py::bytes(stringRef.data, stringRef.length); + return nb::bytes(stringRef.data, stringRef.length); }, "Returns the value of the string attribute as `bytes`"); } @@ -693,12 +821,11 @@ class PyDenseElementsAttribute using PyConcreteAttribute::PyConcreteAttribute; static PyDenseElementsAttribute - getFromList(py::list attributes, std::optional explicitType, + getFromList(nb::list attributes, std::optional explicitType, DefaultingPyMlirContext contextWrapper) { - - const size_t numAttributes = py::len(attributes); + const size_t numAttributes = nb::len(attributes); if (numAttributes == 0) - throw py::value_error("Attributes list must be non-empty."); + throw nb::value_error("Attributes list must be non-empty."); MlirType shapedType; if (explicitType) { @@ -708,8 +835,8 @@ class PyDenseElementsAttribute std::string message; llvm::raw_string_ostream os(message); os << "Expected a static ShapedType for the shaped_type parameter: " - << py::repr(py::cast(*explicitType)); - throw py::value_error(message); + << nb::cast(nb::repr(nb::cast(*explicitType))); + throw nb::value_error(message.c_str()); } shapedType = *explicitType; } else { @@ -722,7 +849,7 @@ class PyDenseElementsAttribute SmallVector mlirAttributes; mlirAttributes.reserve(numAttributes); - for (const py::handle &attribute : attributes) { + for (const nb::handle &attribute : attributes) { MlirAttribute mlirAttribute = pyTryCast(attribute); MlirType attrType = mlirAttributeGetType(mlirAttribute); mlirAttributes.push_back(mlirAttribute); @@ -731,9 +858,11 @@ class PyDenseElementsAttribute std::string message; llvm::raw_string_ostream os(message); os << "All attributes must be of the same type and match " - << "the type parameter: expected=" << py::repr(py::cast(shapedType)) - << ", but got=" << py::repr(py::cast(attrType)); - throw py::value_error(message); + << "the type parameter: expected=" + << nb::cast(nb::repr(nb::cast(shapedType))) + << ", but got=" + << nb::cast(nb::repr(nb::cast(attrType))); + throw nb::value_error(message.c_str()); } } @@ -744,7 +873,7 @@ class PyDenseElementsAttribute } static PyDenseElementsAttribute - getFromBuffer(py::buffer array, bool signless, + getFromBuffer(nb_buffer array, bool signless, std::optional explicitType, std::optional> explicitShape, DefaultingPyMlirContext contextWrapper) { @@ -755,7 +884,7 @@ class PyDenseElementsAttribute } Py_buffer view; if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) { - throw py::error_already_set(); + throw nb::python_error(); } auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); }); @@ -778,25 +907,29 @@ class PyDenseElementsAttribute if (!mlirAttributeIsAInteger(elementAttr) && !mlirAttributeIsAFloat(elementAttr)) { std::string message = "Illegal element type for DenseElementsAttr: "; - message.append(py::repr(py::cast(elementAttr))); - throw py::value_error(message); + message.append( + nb::cast(nb::repr(nb::cast(elementAttr)))); + throw nb::value_error(message.c_str()); } if (!mlirTypeIsAShaped(shapedType) || !mlirShapedTypeHasStaticShape(shapedType)) { std::string message = "Expected a static ShapedType for the shaped_type parameter: "; - message.append(py::repr(py::cast(shapedType))); - throw py::value_error(message); + message.append( + nb::cast(nb::repr(nb::cast(shapedType)))); + throw nb::value_error(message.c_str()); } MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); MlirType attrType = mlirAttributeGetType(elementAttr); if (!mlirTypeEqual(shapedElementType, attrType)) { std::string message = "Shaped element type and attribute type must be equal: shaped="; - message.append(py::repr(py::cast(shapedType))); + message.append( + nb::cast(nb::repr(nb::cast(shapedType)))); message.append(", element="); - message.append(py::repr(py::cast(elementAttr))); - throw py::value_error(message); + message.append( + nb::cast(nb::repr(nb::cast(elementAttr)))); + throw nb::value_error(message.c_str()); } MlirAttribute elements = @@ -806,7 +939,7 @@ class PyDenseElementsAttribute intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } - py::buffer_info accessBuffer() { + nb_buffer_info accessBuffer() { MlirType shapedType = mlirAttributeGetType(*this); MlirType elementType = mlirShapedTypeGetElementType(shapedType); std::string format; @@ -889,32 +1022,36 @@ class PyDenseElementsAttribute static void bindDerived(ClassTy &c) { c.def("__len__", &PyDenseElementsAttribute::dunderLen) .def_static("get", PyDenseElementsAttribute::getFromBuffer, - py::arg("array"), py::arg("signless") = true, - py::arg("type") = py::none(), py::arg("shape") = py::none(), - py::arg("context") = py::none(), + nb::arg("array"), nb::arg("signless") = true, + nb::arg("type").none() = nb::none(), + nb::arg("shape").none() = nb::none(), + nb::arg("context").none() = nb::none(), kDenseElementsAttrGetDocstring) .def_static("get", PyDenseElementsAttribute::getFromList, - py::arg("attrs"), py::arg("type") = py::none(), - py::arg("context") = py::none(), + nb::arg("attrs"), nb::arg("type").none() = nb::none(), + nb::arg("context").none() = nb::none(), kDenseElementsAttrGetFromListDocstring) .def_static("get_splat", PyDenseElementsAttribute::getSplat, - py::arg("shaped_type"), py::arg("element_attr"), + nb::arg("shaped_type"), nb::arg("element_attr"), "Gets a DenseElementsAttr where all values are the same") - .def_property_readonly("is_splat", - [](PyDenseElementsAttribute &self) -> bool { - return mlirDenseElementsAttrIsSplat(self); - }) - .def("get_splat_value", - [](PyDenseElementsAttribute &self) { - if (!mlirDenseElementsAttrIsSplat(self)) - throw py::value_error( - "get_splat_value called on a non-splat attribute"); - return mlirDenseElementsAttrGetSplatValue(self); - }) - .def_buffer(&PyDenseElementsAttribute::accessBuffer); + .def_prop_ro("is_splat", + [](PyDenseElementsAttribute &self) -> bool { + return mlirDenseElementsAttrIsSplat(self); + }) + .def("get_splat_value", [](PyDenseElementsAttribute &self) { + if (!mlirDenseElementsAttrIsSplat(self)) + throw nb::value_error( + "get_splat_value called on a non-splat attribute"); + return mlirDenseElementsAttrGetSplatValue(self); + }); } + static PyType_Slot slots[]; + private: + static int bf_getbuffer(PyObject *exporter, Py_buffer *view, int flags); + static void bf_releasebuffer(PyObject *, Py_buffer *buffer); + static bool isUnsignedIntegerFormat(std::string_view format) { if (format.empty()) return false; @@ -1039,27 +1176,27 @@ class PyDenseElementsAttribute return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf); } - // There is a complication for boolean numpy arrays, as numpy represents them - // as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 booleans - // per byte. + // There is a complication for boolean numpy arrays, as numpy represents + // them as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 + // booleans per byte. static MlirAttribute getBitpackedAttributeFromBooleanBuffer( Py_buffer &view, std::optional> explicitShape, MlirContext &context) { if (llvm::endianness::native != llvm::endianness::little) { - // Given we have no good way of testing the behavior on big-endian systems - // we will throw - throw py::type_error("Constructing a bit-packed MLIR attribute is " + // Given we have no good way of testing the behavior on big-endian + // systems we will throw + throw nb::type_error("Constructing a bit-packed MLIR attribute is " "unsupported on big-endian systems"); } + nb::ndarray, nb::c_contig> unpackedArray( + /*data=*/static_cast(view.buf), + /*shape=*/{static_cast(view.len)}); - py::array_t unpackedArray(view.len, - static_cast(view.buf)); - - py::module numpy = py::module::import("numpy"); - py::object packbitsFunc = numpy.attr("packbits"); - py::object packedBooleans = - packbitsFunc(unpackedArray, "bitorder"_a = "little"); - py::buffer_info pythonBuffer = packedBooleans.cast().request(); + nb::module_ numpy = nb::module_::import_("numpy"); + nb::object packbitsFunc = numpy.attr("packbits"); + nb::object packedBooleans = + packbitsFunc(nb::cast(unpackedArray), "bitorder"_a = "little"); + nb_buffer_info pythonBuffer = nb::cast(packedBooleans).request(); MlirType bitpackedType = getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view); @@ -1073,11 +1210,11 @@ class PyDenseElementsAttribute // This does the opposite transformation of // `getBitpackedAttributeFromBooleanBuffer` - py::buffer_info getBooleanBufferFromBitpackedAttribute() { + nb_buffer_info getBooleanBufferFromBitpackedAttribute() { if (llvm::endianness::native != llvm::endianness::little) { - // Given we have no good way of testing the behavior on big-endian systems - // we will throw - throw py::type_error("Constructing a numpy array from a MLIR attribute " + // Given we have no good way of testing the behavior on big-endian + // systems we will throw + throw nb::type_error("Constructing a numpy array from a MLIR attribute " "is unsupported on big-endian systems"); } @@ -1085,21 +1222,24 @@ class PyDenseElementsAttribute int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8); uint8_t *bitpackedData = static_cast( const_cast(mlirDenseElementsAttrGetRawData(*this))); - py::array_t packedArray(numBitpackedBytes, bitpackedData); + nb::ndarray, nb::c_contig> packedArray( + /*data=*/bitpackedData, + /*shape=*/{static_cast(numBitpackedBytes)}); - py::module numpy = py::module::import("numpy"); - py::object unpackbitsFunc = numpy.attr("unpackbits"); - py::object equalFunc = numpy.attr("equal"); - py::object reshapeFunc = numpy.attr("reshape"); - py::array unpackedBooleans = - unpackbitsFunc(packedArray, "bitorder"_a = "little"); + nb::module_ numpy = nb::module_::import_("numpy"); + nb::object unpackbitsFunc = numpy.attr("unpackbits"); + nb::object equalFunc = numpy.attr("equal"); + nb::object reshapeFunc = numpy.attr("reshape"); + nb::object unpackedBooleans = + unpackbitsFunc(nb::cast(packedArray), "bitorder"_a = "little"); // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array. // We need to: // 1. Slice away the padded bits // 2. Make the boolean array have the correct shape // 3. Convert the array to a boolean array - unpackedBooleans = unpackedBooleans[py::slice(0, numBooleans, 1)]; + unpackedBooleans = unpackedBooleans[nb::slice( + nb::int_(0), nb::int_(numBooleans), nb::int_(1))]; unpackedBooleans = equalFunc(unpackedBooleans, 1); MlirType shapedType = mlirAttributeGetType(*this); @@ -1110,15 +1250,15 @@ class PyDenseElementsAttribute } unpackedBooleans = reshapeFunc(unpackedBooleans, shape); - // Make sure the returned py::buffer_view claims ownership of the data in + // Make sure the returned nb::buffer_view claims ownership of the data in // `pythonBuffer` so it remains valid when Python reads it - py::buffer pythonBuffer = unpackedBooleans.cast(); + nb_buffer pythonBuffer = nb::cast(unpackedBooleans); return pythonBuffer.request(); } template - py::buffer_info bufferInfo(MlirType shapedType, - const char *explicitFormat = nullptr) { + nb_buffer_info bufferInfo(MlirType shapedType, + const char *explicitFormat = nullptr) { intptr_t rank = mlirShapedTypeGetRank(shapedType); // Prepare the data for the buffer_info. // Buffer is configured for read-only access below. @@ -1142,19 +1282,69 @@ class PyDenseElementsAttribute } strides.push_back(sizeof(Type)); } - std::string format; + const char *format; if (explicitFormat) { format = explicitFormat; } else { - format = py::format_descriptor::format(); + format = nb_format_descriptor::format(); } - return py::buffer_info(data, sizeof(Type), format, rank, shape, strides, - /*readonly=*/true); + return nb_buffer_info(data, sizeof(Type), format, rank, std::move(shape), + std::move(strides), + /*readonly=*/true); } }; // namespace -/// Refinement of the PyDenseElementsAttribute for attributes containing integer -/// (and boolean) values. Supports element access. +PyType_Slot PyDenseElementsAttribute::slots[] = { + {Py_bf_getbuffer, + reinterpret_cast(PyDenseElementsAttribute::bf_getbuffer)}, + {Py_bf_releasebuffer, + reinterpret_cast(PyDenseElementsAttribute::bf_releasebuffer)}, + {0, nullptr}, +}; + +/*static*/ int PyDenseElementsAttribute::bf_getbuffer(PyObject *obj, + Py_buffer *view, + int flags) { + view->obj = nullptr; + nb_buffer_info info; + try { + auto *attr = nb::cast(nb::handle(obj)); + info = attr->accessBuffer(); + } catch (nb::python_error &e) { + e.restore(); + nb::chain_error(PyExc_BufferError, "Error converting attribute to buffer"); + return -1; + } + view->obj = obj; + view->ndim = 1; + view->buf = info.ptr; + view->itemsize = info.itemsize; + view->len = info.itemsize; + for (auto s : info.shape) { + view->len *= s; + } + view->readonly = info.readonly; + if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) { + view->format = const_cast(info.format); + } + if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) { + view->ndim = static_cast(info.ndim); + view->strides = info.strides.data(); + view->shape = info.shape.data(); + } + view->suboffsets = nullptr; + view->internal = new nb_buffer_info(std::move(info)); + Py_INCREF(obj); + return 0; +} + +/*static*/ void PyDenseElementsAttribute::bf_releasebuffer(PyObject *, + Py_buffer *view) { + delete reinterpret_cast(view->internal); +} + +/// Refinement of the PyDenseElementsAttribute for attributes containing +/// integer (and boolean) values. Supports element access. class PyDenseIntElementsAttribute : public PyConcreteAttribute { @@ -1163,11 +1353,11 @@ class PyDenseIntElementsAttribute static constexpr const char *pyClassName = "DenseIntElementsAttr"; using PyConcreteAttribute::PyConcreteAttribute; - /// Returns the element at the given linear position. Asserts if the index is - /// out of range. - py::int_ dunderGetItem(intptr_t pos) { + /// Returns the element at the given linear position. Asserts if the index + /// is out of range. + nb::object dunderGetItem(intptr_t pos) { if (pos < 0 || pos >= dunderLen()) { - throw py::index_error("attempt to access out of bounds element"); + throw nb::index_error("attempt to access out of bounds element"); } MlirType type = mlirAttributeGetType(*this); @@ -1175,7 +1365,7 @@ class PyDenseIntElementsAttribute assert(mlirTypeIsAInteger(type) && "expected integer element type in dense int elements attribute"); // Dispatch element extraction to an appropriate C function based on the - // elemental type of the attribute. py::int_ is implicitly constructible + // elemental type of the attribute. nb::int_ is implicitly constructible // from any C++ integral type and handles bitwidth correctly. // TODO: consider caching the type properties in the constructor to avoid // querying them on each element access. @@ -1183,38 +1373,38 @@ class PyDenseIntElementsAttribute bool isUnsigned = mlirIntegerTypeIsUnsigned(type); if (isUnsigned) { if (width == 1) { - return mlirDenseElementsAttrGetBoolValue(*this, pos); + return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos))); } if (width == 8) { - return mlirDenseElementsAttrGetUInt8Value(*this, pos); + return nb::int_(mlirDenseElementsAttrGetUInt8Value(*this, pos)); } if (width == 16) { - return mlirDenseElementsAttrGetUInt16Value(*this, pos); + return nb::int_(mlirDenseElementsAttrGetUInt16Value(*this, pos)); } if (width == 32) { - return mlirDenseElementsAttrGetUInt32Value(*this, pos); + return nb::int_(mlirDenseElementsAttrGetUInt32Value(*this, pos)); } if (width == 64) { - return mlirDenseElementsAttrGetUInt64Value(*this, pos); + return nb::int_(mlirDenseElementsAttrGetUInt64Value(*this, pos)); } } else { if (width == 1) { - return mlirDenseElementsAttrGetBoolValue(*this, pos); + return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos))); } if (width == 8) { - return mlirDenseElementsAttrGetInt8Value(*this, pos); + return nb::int_(mlirDenseElementsAttrGetInt8Value(*this, pos)); } if (width == 16) { - return mlirDenseElementsAttrGetInt16Value(*this, pos); + return nb::int_(mlirDenseElementsAttrGetInt16Value(*this, pos)); } if (width == 32) { - return mlirDenseElementsAttrGetInt32Value(*this, pos); + return nb::int_(mlirDenseElementsAttrGetInt32Value(*this, pos)); } if (width == 64) { - return mlirDenseElementsAttrGetInt64Value(*this, pos); + return nb::int_(mlirDenseElementsAttrGetInt64Value(*this, pos)); } } - throw py::type_error("Unsupported integer type"); + throw nb::type_error("Unsupported integer type"); } static void bindDerived(ClassTy &c) { @@ -1231,7 +1421,7 @@ class PyDenseResourceElementsAttribute using PyConcreteAttribute::PyConcreteAttribute; static PyDenseResourceElementsAttribute - getFromBuffer(py::buffer buffer, const std::string &name, const PyType &type, + getFromBuffer(nb_buffer buffer, const std::string &name, const PyType &type, std::optional alignment, bool isMutable, DefaultingPyMlirContext contextWrapper) { if (!mlirTypeIsAShaped(type)) { @@ -1244,7 +1434,7 @@ class PyDenseResourceElementsAttribute int flags = PyBUF_STRIDES; std::unique_ptr view = std::make_unique(); if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) { - throw py::error_already_set(); + throw nb::python_error(); } // This scope releaser will only release if we haven't yet transferred @@ -1289,12 +1479,12 @@ class PyDenseResourceElementsAttribute } static void bindDerived(ClassTy &c) { - c.def_static("get_from_buffer", - PyDenseResourceElementsAttribute::getFromBuffer, - py::arg("array"), py::arg("name"), py::arg("type"), - py::arg("alignment") = py::none(), - py::arg("is_mutable") = false, py::arg("context") = py::none(), - kDenseResourceElementsAttrGetFromBufferDocstring); + c.def_static( + "get_from_buffer", PyDenseResourceElementsAttribute::getFromBuffer, + nb::arg("array"), nb::arg("name"), nb::arg("type"), + nb::arg("alignment").none() = nb::none(), nb::arg("is_mutable") = false, + nb::arg("context").none() = nb::none(), + kDenseResourceElementsAttrGetFromBufferDocstring); } }; @@ -1318,12 +1508,12 @@ class PyDictAttribute : public PyConcreteAttribute { c.def("__len__", &PyDictAttribute::dunderLen); c.def_static( "get", - [](py::dict attributes, DefaultingPyMlirContext context) { + [](nb::dict attributes, DefaultingPyMlirContext context) { SmallVector mlirNamedAttributes; mlirNamedAttributes.reserve(attributes.size()); - for (auto &it : attributes) { - auto &mlirAttr = it.second.cast(); - auto name = it.first.cast(); + for (std::pair it : attributes) { + auto &mlirAttr = nb::cast(it.second); + auto name = nb::cast(it.first); mlirNamedAttributes.push_back(mlirNamedAttributeGet( mlirIdentifierGet(mlirAttributeGetContext(mlirAttr), toMlirStringRef(name)), @@ -1334,18 +1524,18 @@ class PyDictAttribute : public PyConcreteAttribute { mlirNamedAttributes.data()); return PyDictAttribute(context->getRef(), attr); }, - py::arg("value") = py::dict(), py::arg("context") = py::none(), + nb::arg("value") = nb::dict(), nb::arg("context").none() = nb::none(), "Gets an uniqued dict attribute"); c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { MlirAttribute attr = mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); if (mlirAttributeIsNull(attr)) - throw py::key_error("attempt to access a non-existent attribute"); + throw nb::key_error("attempt to access a non-existent attribute"); return attr; }); c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { if (index < 0 || index >= self.dunderLen()) { - throw py::index_error("attempt to access out of bounds attribute"); + throw nb::index_error("attempt to access out of bounds attribute"); } MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); return PyNamedAttribute( @@ -1365,25 +1555,25 @@ class PyDenseFPElementsAttribute static constexpr const char *pyClassName = "DenseFPElementsAttr"; using PyConcreteAttribute::PyConcreteAttribute; - py::float_ dunderGetItem(intptr_t pos) { + nb::float_ dunderGetItem(intptr_t pos) { if (pos < 0 || pos >= dunderLen()) { - throw py::index_error("attempt to access out of bounds element"); + throw nb::index_error("attempt to access out of bounds element"); } MlirType type = mlirAttributeGetType(*this); type = mlirShapedTypeGetElementType(type); // Dispatch element extraction to an appropriate C function based on the - // elemental type of the attribute. py::float_ is implicitly constructible + // elemental type of the attribute. nb::float_ is implicitly constructible // from float and double. // TODO: consider caching the type properties in the constructor to avoid // querying them on each element access. if (mlirTypeIsAF32(type)) { - return mlirDenseElementsAttrGetFloatValue(*this, pos); + return nb::float_(mlirDenseElementsAttrGetFloatValue(*this, pos)); } if (mlirTypeIsAF64(type)) { - return mlirDenseElementsAttrGetDoubleValue(*this, pos); + return nb::float_(mlirDenseElementsAttrGetDoubleValue(*this, pos)); } - throw py::type_error("Unsupported floating-point type"); + throw nb::type_error("Unsupported floating-point type"); } static void bindDerived(ClassTy &c) { @@ -1406,9 +1596,9 @@ class PyTypeAttribute : public PyConcreteAttribute { MlirAttribute attr = mlirTypeAttrGet(value.get()); return PyTypeAttribute(context->getRef(), attr); }, - py::arg("value"), py::arg("context") = py::none(), + nb::arg("value"), nb::arg("context").none() = nb::none(), "Gets a uniqued Type attribute"); - c.def_property_readonly("value", [](PyTypeAttribute &self) { + c.def_prop_ro("value", [](PyTypeAttribute &self) { return mlirTypeAttrGetValue(self.get()); }); } @@ -1430,7 +1620,7 @@ class PyUnitAttribute : public PyConcreteAttribute { return PyUnitAttribute(context->getRef(), mlirUnitAttrGet(context->get())); }, - py::arg("context") = py::none(), "Create a Unit attribute."); + nb::arg("context").none() = nb::none(), "Create a Unit attribute."); } }; @@ -1453,7 +1643,8 @@ class PyStridedLayoutAttribute ctx->get(), offset, strides.size(), strides.data()); return PyStridedLayoutAttribute(ctx->getRef(), attr); }, - py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(), + nb::arg("offset"), nb::arg("strides"), + nb::arg("context").none() = nb::none(), "Gets a strided layout attribute."); c.def_static( "get_fully_dynamic", @@ -1465,16 +1656,17 @@ class PyStridedLayoutAttribute ctx->get(), dynamic, strides.size(), strides.data()); return PyStridedLayoutAttribute(ctx->getRef(), attr); }, - py::arg("rank"), py::arg("context") = py::none(), - "Gets a strided layout attribute with dynamic offset and strides of a " + nb::arg("rank"), nb::arg("context").none() = nb::none(), + "Gets a strided layout attribute with dynamic offset and strides of " + "a " "given rank."); - c.def_property_readonly( + c.def_prop_ro( "offset", [](PyStridedLayoutAttribute &self) { return mlirStridedLayoutAttrGetOffset(self); }, "Returns the value of the float point attribute"); - c.def_property_readonly( + c.def_prop_ro( "strides", [](PyStridedLayoutAttribute &self) { intptr_t size = mlirStridedLayoutAttrGetNumStrides(self); @@ -1488,63 +1680,64 @@ class PyStridedLayoutAttribute } }; -py::object denseArrayAttributeCaster(PyAttribute &pyAttribute) { +nb::object denseArrayAttributeCaster(PyAttribute &pyAttribute) { if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute)) - return py::cast(PyDenseBoolArrayAttribute(pyAttribute)); + return nb::cast(PyDenseBoolArrayAttribute(pyAttribute)); if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute)) - return py::cast(PyDenseI8ArrayAttribute(pyAttribute)); + return nb::cast(PyDenseI8ArrayAttribute(pyAttribute)); if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute)) - return py::cast(PyDenseI16ArrayAttribute(pyAttribute)); + return nb::cast(PyDenseI16ArrayAttribute(pyAttribute)); if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute)) - return py::cast(PyDenseI32ArrayAttribute(pyAttribute)); + return nb::cast(PyDenseI32ArrayAttribute(pyAttribute)); if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute)) - return py::cast(PyDenseI64ArrayAttribute(pyAttribute)); + return nb::cast(PyDenseI64ArrayAttribute(pyAttribute)); if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute)) - return py::cast(PyDenseF32ArrayAttribute(pyAttribute)); + return nb::cast(PyDenseF32ArrayAttribute(pyAttribute)); if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute)) - return py::cast(PyDenseF64ArrayAttribute(pyAttribute)); + return nb::cast(PyDenseF64ArrayAttribute(pyAttribute)); std::string msg = std::string("Can't cast unknown element type DenseArrayAttr (") + - std::string(py::repr(py::cast(pyAttribute))) + ")"; - throw py::cast_error(msg); + nb::cast(nb::repr(nb::cast(pyAttribute))) + ")"; + throw nb::type_error(msg.c_str()); } -py::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) { +nb::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) { if (PyDenseFPElementsAttribute::isaFunction(pyAttribute)) - return py::cast(PyDenseFPElementsAttribute(pyAttribute)); + return nb::cast(PyDenseFPElementsAttribute(pyAttribute)); if (PyDenseIntElementsAttribute::isaFunction(pyAttribute)) - return py::cast(PyDenseIntElementsAttribute(pyAttribute)); + return nb::cast(PyDenseIntElementsAttribute(pyAttribute)); std::string msg = std::string( "Can't cast unknown element type DenseIntOrFPElementsAttr (") + - std::string(py::repr(py::cast(pyAttribute))) + ")"; - throw py::cast_error(msg); + nb::cast(nb::repr(nb::cast(pyAttribute))) + ")"; + throw nb::type_error(msg.c_str()); } -py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) { +nb::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) { if (PyBoolAttribute::isaFunction(pyAttribute)) - return py::cast(PyBoolAttribute(pyAttribute)); + return nb::cast(PyBoolAttribute(pyAttribute)); if (PyIntegerAttribute::isaFunction(pyAttribute)) - return py::cast(PyIntegerAttribute(pyAttribute)); + return nb::cast(PyIntegerAttribute(pyAttribute)); std::string msg = std::string("Can't cast unknown element type DenseArrayAttr (") + - std::string(py::repr(py::cast(pyAttribute))) + ")"; - throw py::cast_error(msg); + nb::cast(nb::repr(nb::cast(pyAttribute))) + ")"; + throw nb::type_error(msg.c_str()); } -py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) { +nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) { if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute)) - return py::cast(PyFlatSymbolRefAttribute(pyAttribute)); + return nb::cast(PyFlatSymbolRefAttribute(pyAttribute)); if (PySymbolRefAttribute::isaFunction(pyAttribute)) - return py::cast(PySymbolRefAttribute(pyAttribute)); + return nb::cast(PySymbolRefAttribute(pyAttribute)); std::string msg = std::string("Can't cast unknown SymbolRef attribute (") + - std::string(py::repr(py::cast(pyAttribute))) + ")"; - throw py::cast_error(msg); + nb::cast(nb::repr(nb::cast(pyAttribute))) + + ")"; + throw nb::type_error(msg.c_str()); } } // namespace -void mlir::python::populateIRAttributes(py::module &m) { +void mlir::python::populateIRAttributes(nb::module_ &m) { PyAffineMapAttribute::bind(m); PyDenseBoolArrayAttribute::bind(m); PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m); @@ -1562,24 +1755,26 @@ void mlir::python::populateIRAttributes(py::module &m) { PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m); PyGlobals::get().registerTypeCaster( mlirDenseArrayAttrGetTypeID(), - pybind11::cpp_function(denseArrayAttributeCaster)); + nb::cast(nb::cpp_function(denseArrayAttributeCaster))); PyArrayAttribute::bind(m); PyArrayAttribute::PyArrayAttributeIterator::bind(m); PyBoolAttribute::bind(m); - PyDenseElementsAttribute::bind(m); + PyDenseElementsAttribute::bind(m, PyDenseElementsAttribute::slots); PyDenseFPElementsAttribute::bind(m); PyDenseIntElementsAttribute::bind(m); PyGlobals::get().registerTypeCaster( mlirDenseIntOrFPElementsAttrGetTypeID(), - pybind11::cpp_function(denseIntOrFPElementsAttributeCaster)); + nb::cast( + nb::cpp_function(denseIntOrFPElementsAttributeCaster))); PyDenseResourceElementsAttribute::bind(m); PyDictAttribute::bind(m); PySymbolRefAttribute::bind(m); PyGlobals::get().registerTypeCaster( mlirSymbolRefAttrGetTypeID(), - pybind11::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster)); + nb::cast( + nb::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster))); PyFlatSymbolRefAttribute::bind(m); PyOpaqueAttribute::bind(m); @@ -1590,7 +1785,7 @@ void mlir::python::populateIRAttributes(py::module &m) { PyTypeAttribute::bind(m); PyGlobals::get().registerTypeCaster( mlirIntegerAttrGetTypeID(), - pybind11::cpp_function(integerOrBoolAttributeCaster)); + nb::cast(nb::cpp_function(integerOrBoolAttributeCaster))); PyUnitAttribute::bind(m); PyStridedLayoutAttribute::bind(m); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 3e96f8c60ba7cd..ff4ad1a0d806c7 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -6,26 +6,31 @@ // //===----------------------------------------------------------------------===// -#include "IRModule.h" +#include +#include +#include +#include +#include +#include -#include "Globals.h" -#include "PybindUtils.h" +#include +#include +#include "Globals.h" +#include "IRModule.h" +#include "NanobindUtils.h" #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/Debug.h" #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" -#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" -#include -#include - -namespace py = pybind11; -using namespace py::literals; +namespace nb = nanobind; +using namespace nb::literals; using namespace mlir; using namespace mlir::python; @@ -190,18 +195,18 @@ operations. /// Helper for creating an @classmethod. template -py::object classmethod(Func f, Args... args) { - py::object cf = py::cpp_function(f, args...); - return py::reinterpret_borrow((PyClassMethod_New(cf.ptr()))); +nb::object classmethod(Func f, Args... args) { + nb::object cf = nb::cpp_function(f, args...); + return nb::borrow((PyClassMethod_New(cf.ptr()))); } -static py::object +static nb::object createCustomDialectWrapper(const std::string &dialectNamespace, - py::object dialectDescriptor) { + nb::object dialectDescriptor) { auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); if (!dialectClass) { // Use the base class. - return py::cast(PyDialect(std::move(dialectDescriptor))); + return nb::cast(PyDialect(std::move(dialectDescriptor))); } // Create the custom implementation. @@ -212,42 +217,47 @@ static MlirStringRef toMlirStringRef(const std::string &s) { return mlirStringRefCreate(s.data(), s.size()); } +static MlirStringRef toMlirStringRef(const nb::bytes &s) { + return mlirStringRefCreate(static_cast(s.data()), s.size()); +} + /// Create a block, using the current location context if no locations are /// specified. -static MlirBlock createBlock(const py::sequence &pyArgTypes, - const std::optional &pyArgLocs) { +static MlirBlock createBlock(const nb::sequence &pyArgTypes, + const std::optional &pyArgLocs) { SmallVector argTypes; - argTypes.reserve(pyArgTypes.size()); + argTypes.reserve(nb::len(pyArgTypes)); for (const auto &pyType : pyArgTypes) - argTypes.push_back(pyType.cast()); + argTypes.push_back(nb::cast(pyType)); SmallVector argLocs; if (pyArgLocs) { - argLocs.reserve(pyArgLocs->size()); + argLocs.reserve(nb::len(*pyArgLocs)); for (const auto &pyLoc : *pyArgLocs) - argLocs.push_back(pyLoc.cast()); + argLocs.push_back(nb::cast(pyLoc)); } else if (!argTypes.empty()) { argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve()); } if (argTypes.size() != argLocs.size()) - throw py::value_error(("Expected " + Twine(argTypes.size()) + + throw nb::value_error(("Expected " + Twine(argTypes.size()) + " locations, got: " + Twine(argLocs.size())) - .str()); + .str() + .c_str()); return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data()); } /// Wrapper for the global LLVM debugging flag. struct PyGlobalDebugFlag { - static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } + static void set(nb::object &o, bool enable) { mlirEnableGlobalDebug(enable); } - static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); } + static bool get(const nb::object &) { return mlirIsGlobalDebugEnabled(); } - static void bind(py::module &m) { + static void bind(nb::module_ &m) { // Debug flags. - py::class_(m, "_GlobalDebug", py::module_local()) - .def_property_static("flag", &PyGlobalDebugFlag::get, - &PyGlobalDebugFlag::set, "LLVM-wide debug flag") + nb::class_(m, "_GlobalDebug") + .def_prop_rw_static("flag", &PyGlobalDebugFlag::get, + &PyGlobalDebugFlag::set, "LLVM-wide debug flag") .def_static( "set_types", [](const std::string &type) { @@ -268,20 +278,20 @@ struct PyAttrBuilderMap { static bool dunderContains(const std::string &attributeKind) { return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value(); } - static py::function dundeGetItemNamed(const std::string &attributeKind) { + static nb::callable dundeGetItemNamed(const std::string &attributeKind) { auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind); if (!builder) - throw py::key_error(attributeKind); + throw nb::key_error(attributeKind.c_str()); return *builder; } static void dundeSetItemNamed(const std::string &attributeKind, - py::function func, bool replace) { + nb::callable func, bool replace) { PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func), replace); } - static void bind(py::module &m) { - py::class_(m, "AttrBuilder", py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, "AttrBuilder") .def_static("contains", &PyAttrBuilderMap::dunderContains) .def_static("get", &PyAttrBuilderMap::dundeGetItemNamed) .def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed, @@ -295,8 +305,8 @@ struct PyAttrBuilderMap { // PyBlock //------------------------------------------------------------------------------ -py::object PyBlock::getCapsule() { - return py::reinterpret_steal(mlirPythonBlockToCapsule(get())); +nb::object PyBlock::getCapsule() { + return nb::steal(mlirPythonBlockToCapsule(get())); } //------------------------------------------------------------------------------ @@ -315,14 +325,14 @@ class PyRegionIterator { PyRegion dunderNext() { operation->checkValid(); if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { - throw py::stop_iteration(); + throw nb::stop_iteration(); } MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); return PyRegion(operation, region); } - static void bind(py::module &m) { - py::class_(m, "RegionIterator", py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, "RegionIterator") .def("__iter__", &PyRegionIterator::dunderIter) .def("__next__", &PyRegionIterator::dunderNext); } @@ -351,14 +361,14 @@ class PyRegionList { PyRegion dunderGetItem(intptr_t index) { // dunderLen checks validity. if (index < 0 || index >= dunderLen()) { - throw py::index_error("attempt to access out of bounds region"); + throw nb::index_error("attempt to access out of bounds region"); } MlirRegion region = mlirOperationGetRegion(operation->get(), index); return PyRegion(operation, region); } - static void bind(py::module &m) { - py::class_(m, "RegionSequence", py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, "RegionSequence") .def("__len__", &PyRegionList::dunderLen) .def("__iter__", &PyRegionList::dunderIter) .def("__getitem__", &PyRegionList::dunderGetItem); @@ -378,7 +388,7 @@ class PyBlockIterator { PyBlock dunderNext() { operation->checkValid(); if (mlirBlockIsNull(next)) { - throw py::stop_iteration(); + throw nb::stop_iteration(); } PyBlock returnBlock(operation, next); @@ -386,8 +396,8 @@ class PyBlockIterator { return returnBlock; } - static void bind(py::module &m) { - py::class_(m, "BlockIterator", py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, "BlockIterator") .def("__iter__", &PyBlockIterator::dunderIter) .def("__next__", &PyBlockIterator::dunderNext); } @@ -424,7 +434,7 @@ class PyBlockList { PyBlock dunderGetItem(intptr_t index) { operation->checkValid(); if (index < 0) { - throw py::index_error("attempt to access out of bounds block"); + throw nb::index_error("attempt to access out of bounds block"); } MlirBlock block = mlirRegionGetFirstBlock(region); while (!mlirBlockIsNull(block)) { @@ -434,24 +444,26 @@ class PyBlockList { block = mlirBlockGetNextInRegion(block); index -= 1; } - throw py::index_error("attempt to access out of bounds block"); + throw nb::index_error("attempt to access out of bounds block"); } - PyBlock appendBlock(const py::args &pyArgTypes, - const std::optional &pyArgLocs) { + PyBlock appendBlock(const nb::args &pyArgTypes, + const std::optional &pyArgLocs) { operation->checkValid(); - MlirBlock block = createBlock(pyArgTypes, pyArgLocs); + MlirBlock block = + createBlock(nb::cast(pyArgTypes), pyArgLocs); mlirRegionAppendOwnedBlock(region, block); return PyBlock(operation, block); } - static void bind(py::module &m) { - py::class_(m, "BlockList", py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, "BlockList") .def("__getitem__", &PyBlockList::dunderGetItem) .def("__iter__", &PyBlockList::dunderIter) .def("__len__", &PyBlockList::dunderLen) .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring, - py::arg("arg_locs") = std::nullopt); + nb::arg("args"), nb::kw_only(), + nb::arg("arg_locs") = std::nullopt); } private: @@ -466,10 +478,10 @@ class PyOperationIterator { PyOperationIterator &dunderIter() { return *this; } - py::object dunderNext() { + nb::object dunderNext() { parentOperation->checkValid(); if (mlirOperationIsNull(next)) { - throw py::stop_iteration(); + throw nb::stop_iteration(); } PyOperationRef returnOperation = @@ -478,8 +490,8 @@ class PyOperationIterator { return returnOperation->createOpView(); } - static void bind(py::module &m) { - py::class_(m, "OperationIterator", py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, "OperationIterator") .def("__iter__", &PyOperationIterator::dunderIter) .def("__next__", &PyOperationIterator::dunderNext); } @@ -515,10 +527,10 @@ class PyOperationList { return count; } - py::object dunderGetItem(intptr_t index) { + nb::object dunderGetItem(intptr_t index) { parentOperation->checkValid(); if (index < 0) { - throw py::index_error("attempt to access out of bounds operation"); + throw nb::index_error("attempt to access out of bounds operation"); } MlirOperation childOp = mlirBlockGetFirstOperation(block); while (!mlirOperationIsNull(childOp)) { @@ -529,11 +541,11 @@ class PyOperationList { childOp = mlirOperationGetNextInBlock(childOp); index -= 1; } - throw py::index_error("attempt to access out of bounds operation"); + throw nb::index_error("attempt to access out of bounds operation"); } - static void bind(py::module &m) { - py::class_(m, "OperationList", py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, "OperationList") .def("__getitem__", &PyOperationList::dunderGetItem) .def("__iter__", &PyOperationList::dunderIter) .def("__len__", &PyOperationList::dunderLen); @@ -548,7 +560,7 @@ class PyOpOperand { public: PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {} - py::object getOwner() { + nb::object getOwner() { MlirOperation owner = mlirOpOperandGetOwner(opOperand); PyMlirContextRef context = PyMlirContext::forContext(mlirOperationGetContext(owner)); @@ -557,11 +569,10 @@ class PyOpOperand { size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); } - static void bind(py::module &m) { - py::class_(m, "OpOperand", py::module_local()) - .def_property_readonly("owner", &PyOpOperand::getOwner) - .def_property_readonly("operand_number", - &PyOpOperand::getOperandNumber); + static void bind(nb::module_ &m) { + nb::class_(m, "OpOperand") + .def_prop_ro("owner", &PyOpOperand::getOwner) + .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber); } private: @@ -576,15 +587,15 @@ class PyOpOperandIterator { PyOpOperand dunderNext() { if (mlirOpOperandIsNull(opOperand)) - throw py::stop_iteration(); + throw nb::stop_iteration(); PyOpOperand returnOpOperand(opOperand); opOperand = mlirOpOperandGetNextUse(opOperand); return returnOpOperand; } - static void bind(py::module &m) { - py::class_(m, "OpOperandIterator", py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, "OpOperandIterator") .def("__iter__", &PyOpOperandIterator::dunderIter) .def("__next__", &PyOpOperandIterator::dunderNext); } @@ -600,7 +611,7 @@ class PyOpOperandIterator { //------------------------------------------------------------------------------ PyMlirContext::PyMlirContext(MlirContext context) : context(context) { - py::gil_scoped_acquire acquire; + nb::gil_scoped_acquire acquire; auto &liveContexts = getLiveContexts(); liveContexts[context.ptr] = this; } @@ -609,41 +620,36 @@ PyMlirContext::~PyMlirContext() { // Note that the only public way to construct an instance is via the // forContext method, which always puts the associated handle into // liveContexts. - py::gil_scoped_acquire acquire; + nb::gil_scoped_acquire acquire; getLiveContexts().erase(context.ptr); mlirContextDestroy(context); } -py::object PyMlirContext::getCapsule() { - return py::reinterpret_steal(mlirPythonContextToCapsule(get())); +nb::object PyMlirContext::getCapsule() { + return nb::steal(mlirPythonContextToCapsule(get())); } -py::object PyMlirContext::createFromCapsule(py::object capsule) { +nb::object PyMlirContext::createFromCapsule(nb::object capsule) { MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); if (mlirContextIsNull(rawContext)) - throw py::error_already_set(); + throw nb::python_error(); return forContext(rawContext).releaseObject(); } -PyMlirContext *PyMlirContext::createNewContextForInit() { - MlirContext context = mlirContextCreateWithThreading(false); - return new PyMlirContext(context); -} - PyMlirContextRef PyMlirContext::forContext(MlirContext context) { - py::gil_scoped_acquire acquire; + nb::gil_scoped_acquire acquire; auto &liveContexts = getLiveContexts(); auto it = liveContexts.find(context.ptr); if (it == liveContexts.end()) { // Create. PyMlirContext *unownedContextWrapper = new PyMlirContext(context); - py::object pyRef = py::cast(unownedContextWrapper); - assert(pyRef && "cast to py::object failed"); + nb::object pyRef = nb::cast(unownedContextWrapper); + assert(pyRef && "cast to nb::object failed"); liveContexts[context.ptr] = unownedContextWrapper; return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); } // Use existing. - py::object pyRef = py::cast(it->second); + nb::object pyRef = nb::cast(it->second); return PyMlirContextRef(it->second, std::move(pyRef)); } @@ -717,23 +723,23 @@ void PyMlirContext::clearOperationAndInside(PyOperationBase &op) { size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } -pybind11::object PyMlirContext::contextEnter() { - return PyThreadContextEntry::pushContext(*this); +nb::object PyMlirContext::contextEnter(nb::object context) { + return PyThreadContextEntry::pushContext(context); } -void PyMlirContext::contextExit(const pybind11::object &excType, - const pybind11::object &excVal, - const pybind11::object &excTb) { +void PyMlirContext::contextExit(const nb::object &excType, + const nb::object &excVal, + const nb::object &excTb) { PyThreadContextEntry::popContext(*this); } -py::object PyMlirContext::attachDiagnosticHandler(py::object callback) { +nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) { // Note that ownership is transferred to the delete callback below by way of // an explicit inc_ref (borrow). PyDiagnosticHandler *pyHandler = new PyDiagnosticHandler(get(), std::move(callback)); - py::object pyHandlerObject = - py::cast(pyHandler, py::return_value_policy::take_ownership); + nb::object pyHandlerObject = + nb::cast(pyHandler, nb::rv_policy::take_ownership); pyHandlerObject.inc_ref(); // In these C callbacks, the userData is a PyDiagnosticHandler* that is @@ -741,17 +747,17 @@ py::object PyMlirContext::attachDiagnosticHandler(py::object callback) { auto handlerCallback = +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult { PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic); - py::object pyDiagnosticObject = - py::cast(pyDiagnostic, py::return_value_policy::take_ownership); + nb::object pyDiagnosticObject = + nb::cast(pyDiagnostic, nb::rv_policy::take_ownership); auto *pyHandler = static_cast(userData); bool result = false; { // Since this can be called from arbitrary C++ contexts, always get the // gil. - py::gil_scoped_acquire gil; + nb::gil_scoped_acquire gil; try { - result = py::cast(pyHandler->callback(pyDiagnostic)); + result = nb::cast(pyHandler->callback(pyDiagnostic)); } catch (std::exception &e) { fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n", e.what()); @@ -768,8 +774,7 @@ py::object PyMlirContext::attachDiagnosticHandler(py::object callback) { pyHandler->registeredID.reset(); // Decrement reference, balancing the inc_ref() above. - py::object pyHandlerObject = - py::cast(pyHandler, py::return_value_policy::reference); + nb::object pyHandlerObject = nb::cast(pyHandler, nb::rv_policy::reference); pyHandlerObject.dec_ref(); }; @@ -819,9 +824,9 @@ PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { return &stack.back(); } -void PyThreadContextEntry::push(FrameKind frameKind, py::object context, - py::object insertionPoint, - py::object location) { +void PyThreadContextEntry::push(FrameKind frameKind, nb::object context, + nb::object insertionPoint, + nb::object location) { auto &stack = getStack(); stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), std::move(location)); @@ -844,19 +849,19 @@ void PyThreadContextEntry::push(FrameKind frameKind, py::object context, PyMlirContext *PyThreadContextEntry::getContext() { if (!context) return nullptr; - return py::cast(context); + return nb::cast(context); } PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { if (!insertionPoint) return nullptr; - return py::cast(insertionPoint); + return nb::cast(insertionPoint); } PyLocation *PyThreadContextEntry::getLocation() { if (!location) return nullptr; - return py::cast(location); + return nb::cast(location); } PyMlirContext *PyThreadContextEntry::getDefaultContext() { @@ -874,12 +879,11 @@ PyLocation *PyThreadContextEntry::getDefaultLocation() { return tos ? tos->getLocation() : nullptr; } -py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { - py::object contextObj = py::cast(context); - push(FrameKind::Context, /*context=*/contextObj, - /*insertionPoint=*/py::object(), - /*location=*/py::object()); - return contextObj; +nb::object PyThreadContextEntry::pushContext(nb::object context) { + push(FrameKind::Context, /*context=*/context, + /*insertionPoint=*/nb::object(), + /*location=*/nb::object()); + return context; } void PyThreadContextEntry::popContext(PyMlirContext &context) { @@ -892,15 +896,16 @@ void PyThreadContextEntry::popContext(PyMlirContext &context) { stack.pop_back(); } -py::object -PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { - py::object contextObj = +nb::object +PyThreadContextEntry::pushInsertionPoint(nb::object insertionPointObj) { + PyInsertionPoint &insertionPoint = + nb::cast(insertionPointObj); + nb::object contextObj = insertionPoint.getBlock().getParentOperation()->getContext().getObject(); - py::object insertionPointObj = py::cast(insertionPoint); push(FrameKind::InsertionPoint, /*context=*/contextObj, /*insertionPoint=*/insertionPointObj, - /*location=*/py::object()); + /*location=*/nb::object()); return insertionPointObj; } @@ -915,11 +920,11 @@ void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { stack.pop_back(); } -py::object PyThreadContextEntry::pushLocation(PyLocation &location) { - py::object contextObj = location.getContext().getObject(); - py::object locationObj = py::cast(location); +nb::object PyThreadContextEntry::pushLocation(nb::object locationObj) { + PyLocation &location = nb::cast(locationObj); + nb::object contextObj = location.getContext().getObject(); push(FrameKind::Location, /*context=*/contextObj, - /*insertionPoint=*/py::object(), + /*insertionPoint=*/nb::object(), /*location=*/locationObj); return locationObj; } @@ -941,15 +946,15 @@ void PyThreadContextEntry::popLocation(PyLocation &location) { void PyDiagnostic::invalidate() { valid = false; if (materializedNotes) { - for (auto ¬eObject : *materializedNotes) { - PyDiagnostic *note = py::cast(noteObject); + for (nb::handle noteObject : *materializedNotes) { + PyDiagnostic *note = nb::cast(noteObject); note->invalidate(); } } } PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context, - py::object callback) + nb::object callback) : context(context), callback(std::move(callback)) {} PyDiagnosticHandler::~PyDiagnosticHandler() = default; @@ -984,32 +989,36 @@ PyLocation PyDiagnostic::getLocation() { return PyLocation(PyMlirContext::forContext(context), loc); } -py::str PyDiagnostic::getMessage() { +nb::str PyDiagnostic::getMessage() { checkValid(); - py::object fileObject = py::module::import("io").attr("StringIO")(); + nb::object fileObject = nb::module_::import_("io").attr("StringIO")(); PyFileAccumulator accum(fileObject, /*binary=*/false); mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData()); - return fileObject.attr("getvalue")(); + return nb::cast(fileObject.attr("getvalue")()); } -py::tuple PyDiagnostic::getNotes() { +nb::tuple PyDiagnostic::getNotes() { checkValid(); if (materializedNotes) return *materializedNotes; intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic); - materializedNotes = py::tuple(numNotes); + nb::tuple notes = nb::steal(PyTuple_New(numNotes)); for (intptr_t i = 0; i < numNotes; ++i) { MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i); - (*materializedNotes)[i] = PyDiagnostic(noteDiag); + nb::object diagnostic = nb::cast(PyDiagnostic(noteDiag)); + PyTuple_SET_ITEM(notes.ptr(), i, diagnostic.release().ptr()); } + materializedNotes = std::move(notes); + return *materializedNotes; } PyDiagnostic::DiagnosticInfo PyDiagnostic::getInfo() { std::vector notes; - for (py::handle n : getNotes()) - notes.emplace_back(n.cast().getInfo()); - return {getSeverity(), getLocation(), getMessage(), std::move(notes)}; + for (nb::handle n : getNotes()) + notes.emplace_back(nb::cast(n).getInfo()); + return {getSeverity(), getLocation(), nb::cast(getMessage()), + std::move(notes)}; } //------------------------------------------------------------------------------ @@ -1023,22 +1032,21 @@ MlirDialect PyDialects::getDialectForKey(const std::string &key, if (mlirDialectIsNull(dialect)) { std::string msg = (Twine("Dialect '") + key + "' not found").str(); if (attrError) - throw py::attribute_error(msg); - throw py::index_error(msg); + throw nb::attribute_error(msg.c_str()); + throw nb::index_error(msg.c_str()); } return dialect; } -py::object PyDialectRegistry::getCapsule() { - return py::reinterpret_steal( - mlirPythonDialectRegistryToCapsule(*this)); +nb::object PyDialectRegistry::getCapsule() { + return nb::steal(mlirPythonDialectRegistryToCapsule(*this)); } -PyDialectRegistry PyDialectRegistry::createFromCapsule(py::object capsule) { +PyDialectRegistry PyDialectRegistry::createFromCapsule(nb::object capsule) { MlirDialectRegistry rawRegistry = mlirPythonCapsuleToDialectRegistry(capsule.ptr()); if (mlirDialectRegistryIsNull(rawRegistry)) - throw py::error_already_set(); + throw nb::python_error(); return PyDialectRegistry(rawRegistry); } @@ -1046,25 +1054,25 @@ PyDialectRegistry PyDialectRegistry::createFromCapsule(py::object capsule) { // PyLocation //------------------------------------------------------------------------------ -py::object PyLocation::getCapsule() { - return py::reinterpret_steal(mlirPythonLocationToCapsule(*this)); +nb::object PyLocation::getCapsule() { + return nb::steal(mlirPythonLocationToCapsule(*this)); } -PyLocation PyLocation::createFromCapsule(py::object capsule) { +PyLocation PyLocation::createFromCapsule(nb::object capsule) { MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); if (mlirLocationIsNull(rawLoc)) - throw py::error_already_set(); + throw nb::python_error(); return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), rawLoc); } -py::object PyLocation::contextEnter() { - return PyThreadContextEntry::pushLocation(*this); +nb::object PyLocation::contextEnter(nb::object locationObj) { + return PyThreadContextEntry::pushLocation(locationObj); } -void PyLocation::contextExit(const pybind11::object &excType, - const pybind11::object &excVal, - const pybind11::object &excTb) { +void PyLocation::contextExit(const nb::object &excType, + const nb::object &excVal, + const nb::object &excTb) { PyThreadContextEntry::popLocation(*this); } @@ -1087,7 +1095,7 @@ PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) : BaseContextObject(std::move(contextRef)), module(module) {} PyModule::~PyModule() { - py::gil_scoped_acquire acquire; + nb::gil_scoped_acquire acquire; auto &liveModules = getContext()->liveModules; assert(liveModules.count(module.ptr) == 1 && "destroying module not in live map"); @@ -1099,7 +1107,7 @@ PyModuleRef PyModule::forModule(MlirModule module) { MlirContext context = mlirModuleGetContext(module); PyMlirContextRef contextRef = PyMlirContext::forContext(context); - py::gil_scoped_acquire acquire; + nb::gil_scoped_acquire acquire; auto &liveModules = contextRef->liveModules; auto it = liveModules.find(module.ptr); if (it == liveModules.end()) { @@ -1108,8 +1116,7 @@ PyModuleRef PyModule::forModule(MlirModule module) { // Note that the default return value policy on cast is automatic_reference, // which does not take ownership (delete will not be called). // Just be explicit. - py::object pyRef = - py::cast(unownedModule, py::return_value_policy::take_ownership); + nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership); unownedModule->handle = pyRef; liveModules[module.ptr] = std::make_pair(unownedModule->handle, unownedModule); @@ -1117,19 +1124,19 @@ PyModuleRef PyModule::forModule(MlirModule module) { } // Use existing. PyModule *existing = it->second.second; - py::object pyRef = py::reinterpret_borrow(it->second.first); + nb::object pyRef = nb::borrow(it->second.first); return PyModuleRef(existing, std::move(pyRef)); } -py::object PyModule::createFromCapsule(py::object capsule) { +nb::object PyModule::createFromCapsule(nb::object capsule) { MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); if (mlirModuleIsNull(rawModule)) - throw py::error_already_set(); + throw nb::python_error(); return forModule(rawModule).releaseObject(); } -py::object PyModule::getCapsule() { - return py::reinterpret_steal(mlirPythonModuleToCapsule(get())); +nb::object PyModule::getCapsule() { + return nb::steal(mlirPythonModuleToCapsule(get())); } //------------------------------------------------------------------------------ @@ -1158,7 +1165,7 @@ PyOperation::~PyOperation() { PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, MlirOperation operation, - py::object parentKeepAlive) { + nb::object parentKeepAlive) { auto &liveOperations = contextRef->liveOperations; // Create. PyOperation *unownedOperation = @@ -1166,8 +1173,7 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, // Note that the default return value policy on cast is automatic_reference, // which does not take ownership (delete will not be called). // Just be explicit. - py::object pyRef = - py::cast(unownedOperation, py::return_value_policy::take_ownership); + nb::object pyRef = nb::cast(unownedOperation, nb::rv_policy::take_ownership); unownedOperation->handle = pyRef; if (parentKeepAlive) { unownedOperation->parentKeepAlive = std::move(parentKeepAlive); @@ -1178,7 +1184,7 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, MlirOperation operation, - py::object parentKeepAlive) { + nb::object parentKeepAlive) { auto &liveOperations = contextRef->liveOperations; auto it = liveOperations.find(operation.ptr); if (it == liveOperations.end()) { @@ -1188,13 +1194,13 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, } // Use existing. PyOperation *existing = it->second.second; - py::object pyRef = py::reinterpret_borrow(it->second.first); + nb::object pyRef = nb::borrow(it->second.first); return PyOperationRef(existing, std::move(pyRef)); } PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, MlirOperation operation, - py::object parentKeepAlive) { + nb::object parentKeepAlive) { auto &liveOperations = contextRef->liveOperations; assert(liveOperations.count(operation.ptr) == 0 && "cannot create detached operation that already exists"); @@ -1227,12 +1233,12 @@ void PyOperation::checkValid() const { void PyOperationBase::print(std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, - bool assumeVerified, py::object fileObject, + bool assumeVerified, nb::object fileObject, bool binary, bool skipRegions) { PyOperation &operation = getOperation(); operation.checkValid(); if (fileObject.is_none()) - fileObject = py::module::import("sys").attr("stdout"); + fileObject = nb::module_::import_("sys").attr("stdout"); MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); if (largeElementsLimit) @@ -1255,18 +1261,18 @@ void PyOperationBase::print(std::optional largeElementsLimit, mlirOpPrintingFlagsDestroy(flags); } -void PyOperationBase::print(PyAsmState &state, py::object fileObject, +void PyOperationBase::print(PyAsmState &state, nb::object fileObject, bool binary) { PyOperation &operation = getOperation(); operation.checkValid(); if (fileObject.is_none()) - fileObject = py::module::import("sys").attr("stdout"); + fileObject = nb::module_::import_("sys").attr("stdout"); PyFileAccumulator accum(fileObject, binary); mlirOperationPrintWithState(operation, state.get(), accum.getCallback(), accum.getUserData()); } -void PyOperationBase::writeBytecode(const py::object &fileObject, +void PyOperationBase::writeBytecode(const nb::object &fileObject, std::optional bytecodeVersion) { PyOperation &operation = getOperation(); operation.checkValid(); @@ -1282,9 +1288,10 @@ void PyOperationBase::writeBytecode(const py::object &fileObject, operation, config, accum.getCallback(), accum.getUserData()); mlirBytecodeWriterConfigDestroy(config); if (mlirLogicalResultIsFailure(res)) - throw py::value_error((Twine("Unable to honor desired bytecode version ") + + throw nb::value_error((Twine("Unable to honor desired bytecode version ") + Twine(*bytecodeVersion)) - .str()); + .str() + .c_str()); } void PyOperationBase::walk( @@ -1296,7 +1303,7 @@ void PyOperationBase::walk( std::function callback; bool gotException; std::string exceptionWhat; - py::object exceptionType; + nb::object exceptionType; }; UserData userData{callback, false, {}, {}}; MlirOperationWalkCallback walkCallback = [](MlirOperation op, @@ -1304,10 +1311,10 @@ void PyOperationBase::walk( UserData *calleeUserData = static_cast(userData); try { return (calleeUserData->callback)(op); - } catch (py::error_already_set &e) { + } catch (nb::python_error &e) { calleeUserData->gotException = true; - calleeUserData->exceptionWhat = e.what(); - calleeUserData->exceptionType = e.type(); + calleeUserData->exceptionWhat = std::string(e.what()); + calleeUserData->exceptionType = nb::borrow(e.type()); return MlirWalkResult::MlirWalkResultInterrupt; } }; @@ -1319,16 +1326,16 @@ void PyOperationBase::walk( } } -py::object PyOperationBase::getAsm(bool binary, +nb::object PyOperationBase::getAsm(bool binary, std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool assumeVerified, bool skipRegions) { - py::object fileObject; + nb::object fileObject; if (binary) { - fileObject = py::module::import("io").attr("BytesIO")(); + fileObject = nb::module_::import_("io").attr("BytesIO")(); } else { - fileObject = py::module::import("io").attr("StringIO")(); + fileObject = nb::module_::import_("io").attr("StringIO")(); } print(/*largeElementsLimit=*/largeElementsLimit, /*enableDebugInfo=*/enableDebugInfo, @@ -1372,7 +1379,7 @@ bool PyOperationBase::verify() { std::optional PyOperation::getParentOperation() { checkValid(); if (!isAttached()) - throw py::value_error("Detached operations have no parent"); + throw nb::value_error("Detached operations have no parent"); MlirOperation operation = mlirOperationGetParentOperation(get()); if (mlirOperationIsNull(operation)) return {}; @@ -1388,42 +1395,42 @@ PyBlock PyOperation::getBlock() { return PyBlock{std::move(*parentOperation), block}; } -py::object PyOperation::getCapsule() { +nb::object PyOperation::getCapsule() { checkValid(); - return py::reinterpret_steal(mlirPythonOperationToCapsule(get())); + return nb::steal(mlirPythonOperationToCapsule(get())); } -py::object PyOperation::createFromCapsule(py::object capsule) { +nb::object PyOperation::createFromCapsule(nb::object capsule) { MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); if (mlirOperationIsNull(rawOperation)) - throw py::error_already_set(); + throw nb::python_error(); MlirContext rawCtxt = mlirOperationGetContext(rawOperation); return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) .releaseObject(); } static void maybeInsertOperation(PyOperationRef &op, - const py::object &maybeIp) { + const nb::object &maybeIp) { // InsertPoint active? - if (!maybeIp.is(py::cast(false))) { + if (!maybeIp.is(nb::cast(false))) { PyInsertionPoint *ip; if (maybeIp.is_none()) { ip = PyThreadContextEntry::getDefaultInsertionPoint(); } else { - ip = py::cast(maybeIp); + ip = nb::cast(maybeIp); } if (ip) ip->insert(*op.get()); } } -py::object PyOperation::create(const std::string &name, +nb::object PyOperation::create(const std::string &name, std::optional> results, std::optional> operands, - std::optional attributes, + std::optional attributes, std::optional> successors, int regions, DefaultingPyLocation location, - const py::object &maybeIp, bool inferType) { + const nb::object &maybeIp, bool inferType) { llvm::SmallVector mlirOperands; llvm::SmallVector mlirResults; llvm::SmallVector mlirSuccessors; @@ -1431,14 +1438,14 @@ py::object PyOperation::create(const std::string &name, // General parameter validation. if (regions < 0) - throw py::value_error("number of regions must be >= 0"); + throw nb::value_error("number of regions must be >= 0"); // Unpack/validate operands. if (operands) { mlirOperands.reserve(operands->size()); for (PyValue *operand : *operands) { if (!operand) - throw py::value_error("operand value cannot be None"); + throw nb::value_error("operand value cannot be None"); mlirOperands.push_back(operand->get()); } } @@ -1449,38 +1456,38 @@ py::object PyOperation::create(const std::string &name, for (PyType *result : *results) { // TODO: Verify result type originate from the same context. if (!result) - throw py::value_error("result type cannot be None"); + throw nb::value_error("result type cannot be None"); mlirResults.push_back(*result); } } // Unpack/validate attributes. if (attributes) { mlirAttributes.reserve(attributes->size()); - for (auto &it : *attributes) { + for (std::pair it : *attributes) { std::string key; try { - key = it.first.cast(); - } catch (py::cast_error &err) { + key = nb::cast(it.first); + } catch (nb::cast_error &err) { std::string msg = "Invalid attribute key (not a string) when " "attempting to create the operation \"" + name + "\" (" + err.what() + ")"; - throw py::cast_error(msg); + throw nb::type_error(msg.c_str()); } try { - auto &attribute = it.second.cast(); + auto &attribute = nb::cast(it.second); // TODO: Verify attribute originates from the same context. mlirAttributes.emplace_back(std::move(key), attribute); - } catch (py::reference_cast_error &) { + } catch (nb::cast_error &err) { + std::string msg = "Invalid attribute value for the key \"" + key + + "\" when attempting to create the operation \"" + + name + "\" (" + err.what() + ")"; + throw nb::type_error(msg.c_str()); + } catch (std::runtime_error &err) { // This exception seems thrown when the value is "None". std::string msg = "Found an invalid (`None`?) attribute value for the key \"" + key + "\" when attempting to create the operation \"" + name + "\""; - throw py::cast_error(msg); - } catch (py::cast_error &err) { - std::string msg = "Invalid attribute value for the key \"" + key + - "\" when attempting to create the operation \"" + - name + "\" (" + err.what() + ")"; - throw py::cast_error(msg); + throw std::runtime_error(msg); } } } @@ -1490,7 +1497,7 @@ py::object PyOperation::create(const std::string &name, for (auto *successor : *successors) { // TODO: Verify successor originate from the same context. if (!successor) - throw py::value_error("successor block cannot be None"); + throw nb::value_error("successor block cannot be None"); mlirSuccessors.push_back(successor->get()); } } @@ -1535,7 +1542,7 @@ py::object PyOperation::create(const std::string &name, // Construct the operation. MlirOperation operation = mlirOperationCreate(&state); if (!operation.ptr) - throw py::value_error("Operation creation failed"); + throw nb::value_error("Operation creation failed"); PyOperationRef created = PyOperation::createDetached(location->getContext(), operation); maybeInsertOperation(created, maybeIp); @@ -1543,7 +1550,7 @@ py::object PyOperation::create(const std::string &name, return created.getObject(); } -py::object PyOperation::clone(const py::object &maybeIp) { +nb::object PyOperation::clone(const nb::object &maybeIp) { MlirOperation clonedOperation = mlirOperationClone(operation); PyOperationRef cloned = PyOperation::createDetached(getContext(), clonedOperation); @@ -1552,15 +1559,15 @@ py::object PyOperation::clone(const py::object &maybeIp) { return cloned->createOpView(); } -py::object PyOperation::createOpView() { +nb::object PyOperation::createOpView() { checkValid(); MlirIdentifier ident = mlirOperationGetName(get()); MlirStringRef identStr = mlirIdentifierStr(ident); auto operationCls = PyGlobals::get().lookupOperationClass( StringRef(identStr.data, identStr.length)); if (operationCls) - return PyOpView::constructDerived(*operationCls, *getRef().get()); - return py::cast(PyOpView(getRef().getObject())); + return PyOpView::constructDerived(*operationCls, getRef().getObject()); + return nb::cast(PyOpView(getRef().getObject())); } void PyOperation::erase() { @@ -1573,8 +1580,8 @@ void PyOperation::erase() { // PyOpView //------------------------------------------------------------------------------ -static void populateResultTypes(StringRef name, py::list resultTypeList, - const py::object &resultSegmentSpecObj, +static void populateResultTypes(StringRef name, nb::list resultTypeList, + const nb::object &resultSegmentSpecObj, std::vector &resultSegmentLengths, std::vector &resultTypes) { resultTypes.reserve(resultTypeList.size()); @@ -1582,26 +1589,28 @@ static void populateResultTypes(StringRef name, py::list resultTypeList, // Non-variadic result unpacking. for (const auto &it : llvm::enumerate(resultTypeList)) { try { - resultTypes.push_back(py::cast(it.value())); + resultTypes.push_back(nb::cast(it.value())); if (!resultTypes.back()) - throw py::cast_error(); - } catch (py::cast_error &err) { - throw py::value_error((llvm::Twine("Result ") + + throw nb::cast_error(); + } catch (nb::cast_error &err) { + throw nb::value_error((llvm::Twine("Result ") + llvm::Twine(it.index()) + " of operation \"" + name + "\" must be a Type (" + err.what() + ")") - .str()); + .str() + .c_str()); } } } else { // Sized result unpacking. - auto resultSegmentSpec = py::cast>(resultSegmentSpecObj); + auto resultSegmentSpec = nb::cast>(resultSegmentSpecObj); if (resultSegmentSpec.size() != resultTypeList.size()) { - throw py::value_error((llvm::Twine("Operation \"") + name + + throw nb::value_error((llvm::Twine("Operation \"") + name + "\" requires " + llvm::Twine(resultSegmentSpec.size()) + " result segments but was provided " + llvm::Twine(resultTypeList.size())) - .str()); + .str() + .c_str()); } resultSegmentLengths.reserve(resultTypeList.size()); for (const auto &it : @@ -1610,7 +1619,7 @@ static void populateResultTypes(StringRef name, py::list resultTypeList, if (segmentSpec == 1 || segmentSpec == 0) { // Unpack unary element. try { - auto *resultType = py::cast(std::get<0>(it.value())); + auto *resultType = nb::cast(std::get<0>(it.value())); if (resultType) { resultTypes.push_back(resultType); resultSegmentLengths.push_back(1); @@ -1618,14 +1627,20 @@ static void populateResultTypes(StringRef name, py::list resultTypeList, // Allowed to be optional. resultSegmentLengths.push_back(0); } else { - throw py::cast_error("was None and result is not optional"); + throw nb::value_error( + (llvm::Twine("Result ") + llvm::Twine(it.index()) + + " of operation \"" + name + + "\" must be a Type (was None and result is not optional)") + .str() + .c_str()); } - } catch (py::cast_error &err) { - throw py::value_error((llvm::Twine("Result ") + + } catch (nb::cast_error &err) { + throw nb::value_error((llvm::Twine("Result ") + llvm::Twine(it.index()) + " of operation \"" + name + "\" must be a Type (" + err.what() + ")") - .str()); + .str() + .c_str()); } } else if (segmentSpec == -1) { // Unpack sequence by appending. @@ -1635,72 +1650,75 @@ static void populateResultTypes(StringRef name, py::list resultTypeList, resultSegmentLengths.push_back(0); } else { // Unpack the list. - auto segment = py::cast(std::get<0>(it.value())); - for (py::object segmentItem : segment) { - resultTypes.push_back(py::cast(segmentItem)); + auto segment = nb::cast(std::get<0>(it.value())); + for (nb::handle segmentItem : segment) { + resultTypes.push_back(nb::cast(segmentItem)); if (!resultTypes.back()) { - throw py::cast_error("contained a None item"); + throw nb::type_error("contained a None item"); } } - resultSegmentLengths.push_back(segment.size()); + resultSegmentLengths.push_back(nb::len(segment)); } } catch (std::exception &err) { // NOTE: Sloppy to be using a catch-all here, but there are at least // three different unrelated exceptions that can be thrown in the // above "casts". Just keep the scope above small and catch them all. - throw py::value_error((llvm::Twine("Result ") + + throw nb::value_error((llvm::Twine("Result ") + llvm::Twine(it.index()) + " of operation \"" + name + "\" must be a Sequence of Types (" + err.what() + ")") - .str()); + .str() + .c_str()); } } else { - throw py::value_error("Unexpected segment spec"); + throw nb::value_error("Unexpected segment spec"); } } } } -py::object PyOpView::buildGeneric( - const py::object &cls, std::optional resultTypeList, - py::list operandList, std::optional attributes, +nb::object PyOpView::buildGeneric( + const nb::object &cls, std::optional resultTypeList, + nb::list operandList, std::optional attributes, std::optional> successors, std::optional regions, DefaultingPyLocation location, - const py::object &maybeIp) { + const nb::object &maybeIp) { PyMlirContextRef context = location->getContext(); // Class level operation construction metadata. - std::string name = py::cast(cls.attr("OPERATION_NAME")); + std::string name = nb::cast(cls.attr("OPERATION_NAME")); // Operand and result segment specs are either none, which does no // variadic unpacking, or a list of ints with segment sizes, where each // element is either a positive number (typically 1 for a scalar) or -1 to // indicate that it is derived from the length of the same-indexed operand // or result (implying that it is a list at that position). - py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); - py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); + nb::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); + nb::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); std::vector operandSegmentLengths; std::vector resultSegmentLengths; // Validate/determine region count. - auto opRegionSpec = py::cast>(cls.attr("_ODS_REGIONS")); + auto opRegionSpec = nb::cast>(cls.attr("_ODS_REGIONS")); int opMinRegionCount = std::get<0>(opRegionSpec); bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); if (!regions) { regions = opMinRegionCount; } if (*regions < opMinRegionCount) { - throw py::value_error( + throw nb::value_error( (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + llvm::Twine(opMinRegionCount) + " regions but was built with regions=" + llvm::Twine(*regions)) - .str()); + .str() + .c_str()); } if (opHasNoVariadicRegions && *regions > opMinRegionCount) { - throw py::value_error( + throw nb::value_error( (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + llvm::Twine(opMinRegionCount) + " regions but was built with regions=" + llvm::Twine(*regions)) - .str()); + .str() + .c_str()); } // Unpack results. @@ -1717,26 +1735,28 @@ py::object PyOpView::buildGeneric( // Non-sized operand unpacking. for (const auto &it : llvm::enumerate(operandList)) { try { - operands.push_back(py::cast(it.value())); + operands.push_back(nb::cast(it.value())); if (!operands.back()) - throw py::cast_error(); - } catch (py::cast_error &err) { - throw py::value_error((llvm::Twine("Operand ") + + throw nb::cast_error(); + } catch (nb::cast_error &err) { + throw nb::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + " of operation \"" + name + "\" must be a Value (" + err.what() + ")") - .str()); + .str() + .c_str()); } } } else { // Sized operand unpacking. - auto operandSegmentSpec = py::cast>(operandSegmentSpecObj); + auto operandSegmentSpec = nb::cast>(operandSegmentSpecObj); if (operandSegmentSpec.size() != operandList.size()) { - throw py::value_error((llvm::Twine("Operation \"") + name + + throw nb::value_error((llvm::Twine("Operation \"") + name + "\" requires " + llvm::Twine(operandSegmentSpec.size()) + "operand segments but was provided " + llvm::Twine(operandList.size())) - .str()); + .str() + .c_str()); } operandSegmentLengths.reserve(operandList.size()); for (const auto &it : @@ -1745,7 +1765,7 @@ py::object PyOpView::buildGeneric( if (segmentSpec == 1 || segmentSpec == 0) { // Unpack unary element. try { - auto *operandValue = py::cast(std::get<0>(it.value())); + auto *operandValue = nb::cast(std::get<0>(it.value())); if (operandValue) { operands.push_back(operandValue); operandSegmentLengths.push_back(1); @@ -1753,14 +1773,20 @@ py::object PyOpView::buildGeneric( // Allowed to be optional. operandSegmentLengths.push_back(0); } else { - throw py::cast_error("was None and operand is not optional"); + throw nb::value_error( + (llvm::Twine("Operand ") + llvm::Twine(it.index()) + + " of operation \"" + name + + "\" must be a Value (was None and operand is not optional)") + .str() + .c_str()); } - } catch (py::cast_error &err) { - throw py::value_error((llvm::Twine("Operand ") + + } catch (nb::cast_error &err) { + throw nb::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + " of operation \"" + name + "\" must be a Value (" + err.what() + ")") - .str()); + .str() + .c_str()); } } else if (segmentSpec == -1) { // Unpack sequence by appending. @@ -1770,27 +1796,28 @@ py::object PyOpView::buildGeneric( operandSegmentLengths.push_back(0); } else { // Unpack the list. - auto segment = py::cast(std::get<0>(it.value())); - for (py::object segmentItem : segment) { - operands.push_back(py::cast(segmentItem)); + auto segment = nb::cast(std::get<0>(it.value())); + for (nb::handle segmentItem : segment) { + operands.push_back(nb::cast(segmentItem)); if (!operands.back()) { - throw py::cast_error("contained a None item"); + throw nb::type_error("contained a None item"); } } - operandSegmentLengths.push_back(segment.size()); + operandSegmentLengths.push_back(nb::len(segment)); } } catch (std::exception &err) { // NOTE: Sloppy to be using a catch-all here, but there are at least // three different unrelated exceptions that can be thrown in the // above "casts". Just keep the scope above small and catch them all. - throw py::value_error((llvm::Twine("Operand ") + + throw nb::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + " of operation \"" + name + "\" must be a Sequence of Values (" + err.what() + ")") - .str()); + .str() + .c_str()); } } else { - throw py::value_error("Unexpected segment spec"); + throw nb::value_error("Unexpected segment spec"); } } } @@ -1799,13 +1826,13 @@ py::object PyOpView::buildGeneric( if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { // Dup. if (attributes) { - attributes = py::dict(*attributes); + attributes = nb::dict(*attributes); } else { - attributes = py::dict(); + attributes = nb::dict(); } if (attributes->contains("resultSegmentSizes") || attributes->contains("operandSegmentSizes")) { - throw py::value_error("Manually setting a 'resultSegmentSizes' or " + throw nb::value_error("Manually setting a 'resultSegmentSizes' or " "'operandSegmentSizes' attribute is unsupported. " "Use Operation.create for such low-level access."); } @@ -1839,21 +1866,18 @@ py::object PyOpView::buildGeneric( !resultTypeList); } -pybind11::object PyOpView::constructDerived(const pybind11::object &cls, - const PyOperation &operation) { - // TODO: pybind11 2.6 supports a more direct form. - // Upgrade many years from now. - // auto opViewType = py::type::of(); - py::handle opViewType = py::detail::get_type_handle(typeid(PyOpView), true); - py::object instance = cls.attr("__new__")(cls); +nb::object PyOpView::constructDerived(const nb::object &cls, + const nb::object &operation) { + nb::handle opViewType = nb::type(); + nb::object instance = cls.attr("__new__")(cls); opViewType.attr("__init__")(instance, operation); return instance; } -PyOpView::PyOpView(const py::object &operationObject) +PyOpView::PyOpView(const nb::object &operationObject) // Casting through the PyOperationBase base-class and then back to the // Operation lets us accept any PyOperationBase subclass. - : operation(py::cast(operationObject).getOperation()), + : operation(nb::cast(operationObject).getOperation()), operationObject(operation.getRef().getObject()) {} //------------------------------------------------------------------------------ @@ -1869,7 +1893,7 @@ PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) void PyInsertionPoint::insert(PyOperationBase &operationBase) { PyOperation &operation = operationBase.getOperation(); if (operation.isAttached()) - throw py::value_error( + throw nb::value_error( "Attempt to insert operation that is already attached"); block.getParentOperation()->checkValid(); MlirOperation beforeOp = {nullptr}; @@ -1882,7 +1906,7 @@ void PyInsertionPoint::insert(PyOperationBase &operationBase) { // already end in a known terminator (violating this will cause assertion // failures later). if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { - throw py::index_error("Cannot insert operation at the end of a block " + throw nb::index_error("Cannot insert operation at the end of a block " "that already has a terminator. Did you mean to " "use 'InsertionPoint.at_block_terminator(block)' " "versus 'InsertionPoint(block)'?"); @@ -1908,19 +1932,19 @@ PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { MlirOperation terminator = mlirBlockGetTerminator(block.get()); if (mlirOperationIsNull(terminator)) - throw py::value_error("Block has no terminator"); + throw nb::value_error("Block has no terminator"); PyOperationRef terminatorOpRef = PyOperation::forOperation( block.getParentOperation()->getContext(), terminator); return PyInsertionPoint{block, std::move(terminatorOpRef)}; } -py::object PyInsertionPoint::contextEnter() { - return PyThreadContextEntry::pushInsertionPoint(*this); +nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) { + return PyThreadContextEntry::pushInsertionPoint(insertPoint); } -void PyInsertionPoint::contextExit(const pybind11::object &excType, - const pybind11::object &excVal, - const pybind11::object &excTb) { +void PyInsertionPoint::contextExit(const nb::object &excType, + const nb::object &excVal, + const nb::object &excTb) { PyThreadContextEntry::popInsertionPoint(*this); } @@ -1932,14 +1956,14 @@ bool PyAttribute::operator==(const PyAttribute &other) const { return mlirAttributeEqual(attr, other.attr); } -py::object PyAttribute::getCapsule() { - return py::reinterpret_steal(mlirPythonAttributeToCapsule(*this)); +nb::object PyAttribute::getCapsule() { + return nb::steal(mlirPythonAttributeToCapsule(*this)); } -PyAttribute PyAttribute::createFromCapsule(py::object capsule) { +PyAttribute PyAttribute::createFromCapsule(nb::object capsule) { MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); if (mlirAttributeIsNull(rawAttr)) - throw py::error_already_set(); + throw nb::python_error(); return PyAttribute( PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); } @@ -1964,14 +1988,14 @@ bool PyType::operator==(const PyType &other) const { return mlirTypeEqual(type, other.type); } -py::object PyType::getCapsule() { - return py::reinterpret_steal(mlirPythonTypeToCapsule(*this)); +nb::object PyType::getCapsule() { + return nb::steal(mlirPythonTypeToCapsule(*this)); } -PyType PyType::createFromCapsule(py::object capsule) { +PyType PyType::createFromCapsule(nb::object capsule) { MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); if (mlirTypeIsNull(rawType)) - throw py::error_already_set(); + throw nb::python_error(); return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), rawType); } @@ -1980,14 +2004,14 @@ PyType PyType::createFromCapsule(py::object capsule) { // PyTypeID. //------------------------------------------------------------------------------ -py::object PyTypeID::getCapsule() { - return py::reinterpret_steal(mlirPythonTypeIDToCapsule(*this)); +nb::object PyTypeID::getCapsule() { + return nb::steal(mlirPythonTypeIDToCapsule(*this)); } -PyTypeID PyTypeID::createFromCapsule(py::object capsule) { +PyTypeID PyTypeID::createFromCapsule(nb::object capsule) { MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr()); if (mlirTypeIDIsNull(mlirTypeID)) - throw py::error_already_set(); + throw nb::python_error(); return PyTypeID(mlirTypeID); } bool PyTypeID::operator==(const PyTypeID &other) const { @@ -1998,36 +2022,36 @@ bool PyTypeID::operator==(const PyTypeID &other) const { // PyValue and subclasses. //------------------------------------------------------------------------------ -pybind11::object PyValue::getCapsule() { - return py::reinterpret_steal(mlirPythonValueToCapsule(get())); +nb::object PyValue::getCapsule() { + return nb::steal(mlirPythonValueToCapsule(get())); } -pybind11::object PyValue::maybeDownCast() { +nb::object PyValue::maybeDownCast() { MlirType type = mlirValueGetType(get()); MlirTypeID mlirTypeID = mlirTypeGetTypeID(type); assert(!mlirTypeIDIsNull(mlirTypeID) && "mlirTypeID was expected to be non-null."); - std::optional valueCaster = + std::optional valueCaster = PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type)); - // py::return_value_policy::move means use std::move to move the return value + // nb::rv_policy::move means use std::move to move the return value // contents into a new instance that will be owned by Python. - py::object thisObj = py::cast(this, py::return_value_policy::move); + nb::object thisObj = nb::cast(this, nb::rv_policy::move); if (!valueCaster) return thisObj; return valueCaster.value()(thisObj); } -PyValue PyValue::createFromCapsule(pybind11::object capsule) { +PyValue PyValue::createFromCapsule(nb::object capsule) { MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); if (mlirValueIsNull(value)) - throw py::error_already_set(); + throw nb::python_error(); MlirOperation owner; if (mlirValueIsAOpResult(value)) owner = mlirOpResultGetOwner(value); if (mlirValueIsABlockArgument(value)) owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); if (mlirOperationIsNull(owner)) - throw py::error_already_set(); + throw nb::python_error(); MlirContext ctx = mlirOperationGetContext(owner); PyOperationRef ownerRef = PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); @@ -2042,16 +2066,17 @@ PySymbolTable::PySymbolTable(PyOperationBase &operation) : operation(operation.getOperation().getRef()) { symbolTable = mlirSymbolTableCreate(operation.getOperation().get()); if (mlirSymbolTableIsNull(symbolTable)) { - throw py::cast_error("Operation is not a Symbol Table."); + throw nb::type_error("Operation is not a Symbol Table."); } } -py::object PySymbolTable::dunderGetItem(const std::string &name) { +nb::object PySymbolTable::dunderGetItem(const std::string &name) { operation->checkValid(); MlirOperation symbol = mlirSymbolTableLookup( symbolTable, mlirStringRefCreate(name.data(), name.length())); if (mlirOperationIsNull(symbol)) - throw py::key_error("Symbol '" + name + "' not in the symbol table."); + throw nb::key_error( + ("Symbol '" + name + "' not in the symbol table.").c_str()); return PyOperation::forOperation(operation->getContext(), symbol, operation.getObject()) @@ -2069,8 +2094,8 @@ void PySymbolTable::erase(PyOperationBase &symbol) { } void PySymbolTable::dunderDel(const std::string &name) { - py::object operation = dunderGetItem(name); - erase(py::cast(operation)); + nb::object operation = dunderGetItem(name); + erase(nb::cast(operation)); } MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) { @@ -2079,7 +2104,7 @@ MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) { MlirAttribute symbolAttr = mlirOperationGetAttributeByName( symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName()); if (mlirAttributeIsNull(symbolAttr)) - throw py::value_error("Expected operation to have a symbol name."); + throw nb::value_error("Expected operation to have a symbol name."); return mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()); } @@ -2091,7 +2116,7 @@ MlirAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { MlirAttribute existingNameAttr = mlirOperationGetAttributeByName(operation.get(), attrName); if (mlirAttributeIsNull(existingNameAttr)) - throw py::value_error("Expected operation to have a symbol name."); + throw nb::value_error("Expected operation to have a symbol name."); return existingNameAttr; } @@ -2104,7 +2129,7 @@ void PySymbolTable::setSymbolName(PyOperationBase &symbol, MlirAttribute existingNameAttr = mlirOperationGetAttributeByName(operation.get(), attrName); if (mlirAttributeIsNull(existingNameAttr)) - throw py::value_error("Expected operation to have a symbol name."); + throw nb::value_error("Expected operation to have a symbol name."); MlirAttribute newNameAttr = mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name)); mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr); @@ -2117,7 +2142,7 @@ MlirAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { MlirAttribute existingVisAttr = mlirOperationGetAttributeByName(operation.get(), attrName); if (mlirAttributeIsNull(existingVisAttr)) - throw py::value_error("Expected operation to have a symbol visibility."); + throw nb::value_error("Expected operation to have a symbol visibility."); return existingVisAttr; } @@ -2125,7 +2150,7 @@ void PySymbolTable::setVisibility(PyOperationBase &symbol, const std::string &visibility) { if (visibility != "public" && visibility != "private" && visibility != "nested") - throw py::value_error( + throw nb::value_error( "Expected visibility to be 'public', 'private' or 'nested'"); PyOperation &operation = symbol.getOperation(); operation.checkValid(); @@ -2133,7 +2158,7 @@ void PySymbolTable::setVisibility(PyOperationBase &symbol, MlirAttribute existingVisAttr = mlirOperationGetAttributeByName(operation.get(), attrName); if (mlirAttributeIsNull(existingVisAttr)) - throw py::value_error("Expected operation to have a symbol visibility."); + throw nb::value_error("Expected operation to have a symbol visibility."); MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(visibility)); mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr); @@ -2148,20 +2173,20 @@ void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol, toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol), from.getOperation()))) - throw py::value_error("Symbol rename failed"); + throw nb::value_error("Symbol rename failed"); } void PySymbolTable::walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, - py::object callback) { + nb::object callback) { PyOperation &fromOperation = from.getOperation(); fromOperation.checkValid(); struct UserData { PyMlirContextRef context; - py::object callback; + nb::object callback; bool gotException; std::string exceptionWhat; - py::object exceptionType; + nb::object exceptionType; }; UserData userData{ fromOperation.getContext(), std::move(callback), false, {}, {}}; @@ -2175,10 +2200,10 @@ void PySymbolTable::walkSymbolTables(PyOperationBase &from, return; try { calleeUserData->callback(pyFoundOp.getObject(), isVisible); - } catch (py::error_already_set &e) { + } catch (nb::python_error &e) { calleeUserData->gotException = true; calleeUserData->exceptionWhat = e.what(); - calleeUserData->exceptionType = e.type(); + calleeUserData->exceptionType = nb::borrow(e.type()); } }, static_cast(&userData)); @@ -2200,7 +2225,7 @@ class PyConcreteValue : public PyValue { // IsAFunctionTy isaFunction // const char *pyClassName // and redefine bindDerived. - using ClassTy = py::class_; + using ClassTy = nb::class_; using IsAFunctionTy = bool (*)(MlirValue); PyConcreteValue() = default; @@ -2213,25 +2238,26 @@ class PyConcreteValue : public PyValue { /// type mismatches. static MlirValue castFrom(PyValue &orig) { if (!DerivedTy::isaFunction(orig.get())) { - auto origRepr = py::repr(py::cast(orig)).cast(); - throw py::value_error((Twine("Cannot cast value to ") + + auto origRepr = nb::cast(nb::repr(nb::cast(orig))); + throw nb::value_error((Twine("Cannot cast value to ") + DerivedTy::pyClassName + " (from " + origRepr + ")") - .str()); + .str() + .c_str()); } return orig.get(); } /// Binds the Python module objects to functions of this class. - static void bind(py::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); - cls.def(py::init(), py::keep_alive<0, 1>(), py::arg("value")); + static void bind(nb::module_ &m) { + auto cls = ClassTy(m, DerivedTy::pyClassName); + cls.def(nb::init(), nb::keep_alive<0, 1>(), nb::arg("value")); cls.def_static( "isinstance", [](PyValue &otherValue) -> bool { return DerivedTy::isaFunction(otherValue); }, - py::arg("other_value")); + nb::arg("other_value")); cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](DerivedTy &self) { return self.maybeDownCast(); }); DerivedTy::bindDerived(cls); @@ -2249,11 +2275,11 @@ class PyBlockArgument : public PyConcreteValue { using PyConcreteValue::PyConcreteValue; static void bindDerived(ClassTy &c) { - c.def_property_readonly("owner", [](PyBlockArgument &self) { + c.def_prop_ro("owner", [](PyBlockArgument &self) { return PyBlock(self.getParentOperation(), mlirBlockArgumentGetOwner(self.get())); }); - c.def_property_readonly("arg_number", [](PyBlockArgument &self) { + c.def_prop_ro("arg_number", [](PyBlockArgument &self) { return mlirBlockArgumentGetArgNumber(self.get()); }); c.def( @@ -2261,7 +2287,7 @@ class PyBlockArgument : public PyConcreteValue { [](PyBlockArgument &self, PyType type) { return mlirBlockArgumentSetType(self.get(), type); }, - py::arg("type")); + nb::arg("type")); } }; @@ -2273,14 +2299,14 @@ class PyOpResult : public PyConcreteValue { using PyConcreteValue::PyConcreteValue; static void bindDerived(ClassTy &c) { - c.def_property_readonly("owner", [](PyOpResult &self) { + c.def_prop_ro("owner", [](PyOpResult &self) { assert( mlirOperationEqual(self.getParentOperation()->get(), mlirOpResultGetOwner(self.get())) && "expected the owner of the value in Python to match that in the IR"); return self.getParentOperation().getObject(); }); - c.def_property_readonly("result_number", [](PyOpResult &self) { + c.def_prop_ro("result_number", [](PyOpResult &self) { return mlirOpResultGetResultNumber(self.get()); }); } @@ -2317,7 +2343,7 @@ class PyBlockArgumentList operation(std::move(operation)), block(block) {} static void bindDerived(ClassTy &c) { - c.def_property_readonly("types", [](PyBlockArgumentList &self) { + c.def_prop_ro("types", [](PyBlockArgumentList &self) { return getValueTypes(self, self.operation->getContext()); }); } @@ -2422,10 +2448,10 @@ class PyOpResultList : public Sliceable { operation(std::move(operation)) {} static void bindDerived(ClassTy &c) { - c.def_property_readonly("types", [](PyOpResultList &self) { + c.def_prop_ro("types", [](PyOpResultList &self) { return getValueTypes(self, self.operation->getContext()); }); - c.def_property_readonly("owner", [](PyOpResultList &self) { + c.def_prop_ro("owner", [](PyOpResultList &self) { return self.operation->createOpView(); }); } @@ -2508,14 +2534,14 @@ class PyOpAttributeMap { MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(name)); if (mlirAttributeIsNull(attr)) { - throw py::key_error("attempt to access a non-existent attribute"); + throw nb::key_error("attempt to access a non-existent attribute"); } return attr; } PyNamedAttribute dunderGetItemIndexed(intptr_t index) { if (index < 0 || index >= dunderLen()) { - throw py::index_error("attempt to access out of bounds attribute"); + throw nb::index_error("attempt to access out of bounds attribute"); } MlirNamedAttribute namedAttr = mlirOperationGetAttribute(operation->get(), index); @@ -2534,7 +2560,7 @@ class PyOpAttributeMap { int removed = mlirOperationRemoveAttributeByName(operation->get(), toMlirStringRef(name)); if (!removed) - throw py::key_error("attempt to delete a non-existent attribute"); + throw nb::key_error("attempt to delete a non-existent attribute"); } intptr_t dunderLen() { @@ -2546,8 +2572,8 @@ class PyOpAttributeMap { operation->get(), toMlirStringRef(name))); } - static void bind(py::module &m) { - py::class_(m, "OpAttributeMap", py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, "OpAttributeMap") .def("__contains__", &PyOpAttributeMap::dunderContains) .def("__len__", &PyOpAttributeMap::dunderLen) .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) @@ -2566,21 +2592,21 @@ class PyOpAttributeMap { // Populates the core exports of the 'ir' submodule. //------------------------------------------------------------------------------ -void mlir::python::populateIRCore(py::module &m) { +void mlir::python::populateIRCore(nb::module_ &m) { //---------------------------------------------------------------------------- // Enums. //---------------------------------------------------------------------------- - py::enum_(m, "DiagnosticSeverity", py::module_local()) + nb::enum_(m, "DiagnosticSeverity") .value("ERROR", MlirDiagnosticError) .value("WARNING", MlirDiagnosticWarning) .value("NOTE", MlirDiagnosticNote) .value("REMARK", MlirDiagnosticRemark); - py::enum_(m, "WalkOrder", py::module_local()) + nb::enum_(m, "WalkOrder") .value("PRE_ORDER", MlirWalkPreOrder) .value("POST_ORDER", MlirWalkPostOrder); - py::enum_(m, "WalkResult", py::module_local()) + nb::enum_(m, "WalkResult") .value("ADVANCE", MlirWalkResultAdvance) .value("INTERRUPT", MlirWalkResultInterrupt) .value("SKIP", MlirWalkResultSkip); @@ -2588,33 +2614,37 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of Diagnostics. //---------------------------------------------------------------------------- - py::class_(m, "Diagnostic", py::module_local()) - .def_property_readonly("severity", &PyDiagnostic::getSeverity) - .def_property_readonly("location", &PyDiagnostic::getLocation) - .def_property_readonly("message", &PyDiagnostic::getMessage) - .def_property_readonly("notes", &PyDiagnostic::getNotes) - .def("__str__", [](PyDiagnostic &self) -> py::str { + nb::class_(m, "Diagnostic") + .def_prop_ro("severity", &PyDiagnostic::getSeverity) + .def_prop_ro("location", &PyDiagnostic::getLocation) + .def_prop_ro("message", &PyDiagnostic::getMessage) + .def_prop_ro("notes", &PyDiagnostic::getNotes) + .def("__str__", [](PyDiagnostic &self) -> nb::str { if (!self.isValid()) - return ""; + return nb::str(""); return self.getMessage(); }); - py::class_(m, "DiagnosticInfo", - py::module_local()) - .def(py::init<>([](PyDiagnostic diag) { return diag.getInfo(); })) - .def_readonly("severity", &PyDiagnostic::DiagnosticInfo::severity) - .def_readonly("location", &PyDiagnostic::DiagnosticInfo::location) - .def_readonly("message", &PyDiagnostic::DiagnosticInfo::message) - .def_readonly("notes", &PyDiagnostic::DiagnosticInfo::notes) + nb::class_(m, "DiagnosticInfo") + .def("__init__", + [](PyDiagnostic::DiagnosticInfo &self, PyDiagnostic diag) { + new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo()); + }) + .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity) + .def_ro("location", &PyDiagnostic::DiagnosticInfo::location) + .def_ro("message", &PyDiagnostic::DiagnosticInfo::message) + .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes) .def("__str__", [](PyDiagnostic::DiagnosticInfo &self) { return self.message; }); - py::class_(m, "DiagnosticHandler", py::module_local()) + nb::class_(m, "DiagnosticHandler") .def("detach", &PyDiagnosticHandler::detach) - .def_property_readonly("attached", &PyDiagnosticHandler::isAttached) - .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError) + .def_prop_ro("attached", &PyDiagnosticHandler::isAttached) + .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError) .def("__enter__", &PyDiagnosticHandler::contextEnter) - .def("__exit__", &PyDiagnosticHandler::contextExit); + .def("__exit__", &PyDiagnosticHandler::contextExit, + nb::arg("exc_type").none(), nb::arg("exc_value").none(), + nb::arg("traceback").none()); //---------------------------------------------------------------------------- // Mapping of MlirContext. @@ -2622,8 +2652,12 @@ void mlir::python::populateIRCore(py::module &m) { // __init__.py will subclass it with site-specific functionality and set a // "Context" attribute on this module. //---------------------------------------------------------------------------- - py::class_(m, "_BaseContext", py::module_local()) - .def(py::init<>(&PyMlirContext::createNewContextForInit)) + nb::class_(m, "_BaseContext") + .def("__init__", + [](PyMlirContext &self) { + MlirContext context = mlirContextCreateWithThreading(false); + new (&self) PyMlirContext(context); + }) .def_static("_get_live_count", &PyMlirContext::getLiveCount) .def("_get_context_again", [](PyMlirContext &self) { @@ -2635,28 +2669,28 @@ void mlir::python::populateIRCore(py::module &m) { &PyMlirContext::getLiveOperationObjects) .def("_clear_live_operations", &PyMlirContext::clearLiveOperations) .def("_clear_live_operations_inside", - py::overload_cast( + nb::overload_cast( &PyMlirContext::clearOperationsInside)) .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyMlirContext::getCapsule) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) .def("__enter__", &PyMlirContext::contextEnter) - .def("__exit__", &PyMlirContext::contextExit) - .def_property_readonly_static( + .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(), + nb::arg("exc_value").none(), nb::arg("traceback").none()) + .def_prop_ro_static( "current", - [](py::object & /*class*/) { + [](nb::object & /*class*/) { auto *context = PyThreadContextEntry::getDefaultContext(); if (!context) - return py::none().cast(); - return py::cast(context); + return nb::none(); + return nb::cast(context); }, "Gets the Context bound to the current thread or raises ValueError") - .def_property_readonly( + .def_prop_ro( "dialects", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, "Gets a container for accessing dialects by name") - .def_property_readonly( + .def_prop_ro( "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, "Alias for 'dialect'") .def( @@ -2665,14 +2699,14 @@ void mlir::python::populateIRCore(py::module &m) { MlirDialect dialect = mlirContextGetOrLoadDialect( self.get(), {name.data(), name.size()}); if (mlirDialectIsNull(dialect)) { - throw py::value_error( - (Twine("Dialect '") + name + "' not found").str()); + throw nb::value_error( + (Twine("Dialect '") + name + "' not found").str().c_str()); } return PyDialectDescriptor(self.getRef(), dialect); }, - py::arg("dialect_name"), + nb::arg("dialect_name"), "Gets or loads a dialect by name, returning its descriptor object") - .def_property( + .def_prop_rw( "allow_unregistered_dialects", [](PyMlirContext &self) -> bool { return mlirContextGetAllowUnregisteredDialects(self.get()); @@ -2681,32 +2715,32 @@ void mlir::python::populateIRCore(py::module &m) { mlirContextSetAllowUnregisteredDialects(self.get(), value); }) .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler, - py::arg("callback"), + nb::arg("callback"), "Attaches a diagnostic handler that will receive callbacks") .def( "enable_multithreading", [](PyMlirContext &self, bool enable) { mlirContextEnableMultithreading(self.get(), enable); }, - py::arg("enable")) + nb::arg("enable")) .def( "is_registered_operation", [](PyMlirContext &self, std::string &name) { return mlirContextIsRegisteredOperation( self.get(), MlirStringRef{name.data(), name.size()}); }, - py::arg("operation_name")) + nb::arg("operation_name")) .def( "append_dialect_registry", [](PyMlirContext &self, PyDialectRegistry ®istry) { mlirContextAppendDialectRegistry(self.get(), registry); }, - py::arg("registry")) - .def_property("emit_error_diagnostics", nullptr, - &PyMlirContext::setEmitErrorDiagnostics, - "Emit error diagnostics to diagnostic handlers. By default " - "error diagnostics are captured and reported through " - "MLIRError exceptions.") + nb::arg("registry")) + .def_prop_rw("emit_error_diagnostics", nullptr, + &PyMlirContext::setEmitErrorDiagnostics, + "Emit error diagnostics to diagnostic handlers. By default " + "error diagnostics are captured and reported through " + "MLIRError exceptions.") .def("load_all_available_dialects", [](PyMlirContext &self) { mlirContextLoadAllAvailableDialects(self.get()); }); @@ -2714,13 +2748,12 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of PyDialectDescriptor //---------------------------------------------------------------------------- - py::class_(m, "DialectDescriptor", py::module_local()) - .def_property_readonly("namespace", - [](PyDialectDescriptor &self) { - MlirStringRef ns = - mlirDialectGetNamespace(self.get()); - return py::str(ns.data, ns.length); - }) + nb::class_(m, "DialectDescriptor") + .def_prop_ro("namespace", + [](PyDialectDescriptor &self) { + MlirStringRef ns = mlirDialectGetNamespace(self.get()); + return nb::str(ns.data, ns.length); + }) .def("__repr__", [](PyDialectDescriptor &self) { MlirStringRef ns = mlirDialectGetNamespace(self.get()); std::string repr("(m, "Dialects", py::module_local()) + nb::class_(m, "Dialects") .def("__getitem__", [=](PyDialects &self, std::string keyName) { MlirDialect dialect = self.getDialectForKey(keyName, /*attrError=*/false); - py::object descriptor = - py::cast(PyDialectDescriptor{self.getContext(), dialect}); + nb::object descriptor = + nb::cast(PyDialectDescriptor{self.getContext(), dialect}); return createCustomDialectWrapper(keyName, std::move(descriptor)); }) .def("__getattr__", [=](PyDialects &self, std::string attrName) { MlirDialect dialect = self.getDialectForKey(attrName, /*attrError=*/true); - py::object descriptor = - py::cast(PyDialectDescriptor{self.getContext(), dialect}); + nb::object descriptor = + nb::cast(PyDialectDescriptor{self.getContext(), dialect}); return createCustomDialectWrapper(attrName, std::move(descriptor)); }); //---------------------------------------------------------------------------- // Mapping of PyDialect //---------------------------------------------------------------------------- - py::class_(m, "Dialect", py::module_local()) - .def(py::init(), py::arg("descriptor")) - .def_property_readonly( - "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) - .def("__repr__", [](py::object self) { + nb::class_(m, "Dialect") + .def(nb::init(), nb::arg("descriptor")) + .def_prop_ro("descriptor", + [](PyDialect &self) { return self.getDescriptor(); }) + .def("__repr__", [](nb::object self) { auto clazz = self.attr("__class__"); - return py::str(""); + return nb::str(""); }); //---------------------------------------------------------------------------- // Mapping of PyDialectRegistry //---------------------------------------------------------------------------- - py::class_(m, "DialectRegistry", py::module_local()) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyDialectRegistry::getCapsule) + nb::class_(m, "DialectRegistry") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule) - .def(py::init<>()); + .def(nb::init<>()); //---------------------------------------------------------------------------- // Mapping of Location //---------------------------------------------------------------------------- - py::class_(m, "Location", py::module_local()) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) + nb::class_(m, "Location") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) .def("__enter__", &PyLocation::contextEnter) - .def("__exit__", &PyLocation::contextExit) + .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(), + nb::arg("exc_value").none(), nb::arg("traceback").none()) .def("__eq__", [](PyLocation &self, PyLocation &other) -> bool { return mlirLocationEqual(self, other); }) - .def("__eq__", [](PyLocation &self, py::object other) { return false; }) - .def_property_readonly_static( + .def("__eq__", [](PyLocation &self, nb::object other) { return false; }) + .def_prop_ro_static( "current", - [](py::object & /*class*/) { + [](nb::object & /*class*/) { auto *loc = PyThreadContextEntry::getDefaultLocation(); if (!loc) - throw py::value_error("No current Location"); + throw nb::value_error("No current Location"); return loc; }, "Gets the Location bound to the current thread or raises ValueError") @@ -2801,14 +2834,14 @@ void mlir::python::populateIRCore(py::module &m) { return PyLocation(context->getRef(), mlirLocationUnknownGet(context->get())); }, - py::arg("context") = py::none(), + nb::arg("context").none() = nb::none(), "Gets a Location representing an unknown location") .def_static( "callsite", [](PyLocation callee, const std::vector &frames, DefaultingPyMlirContext context) { if (frames.empty()) - throw py::value_error("No caller frames provided"); + throw nb::value_error("No caller frames provided"); MlirLocation caller = frames.back().get(); for (const PyLocation &frame : llvm::reverse(llvm::ArrayRef(frames).drop_back())) @@ -2816,7 +2849,8 @@ void mlir::python::populateIRCore(py::module &m) { return PyLocation(context->getRef(), mlirLocationCallSiteGet(callee.get(), caller)); }, - py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(), + nb::arg("callee"), nb::arg("frames"), + nb::arg("context").none() = nb::none(), kContextGetCallSiteLocationDocstring) .def_static( "file", @@ -2827,8 +2861,9 @@ void mlir::python::populateIRCore(py::module &m) { mlirLocationFileLineColGet( context->get(), toMlirStringRef(filename), line, col)); }, - py::arg("filename"), py::arg("line"), py::arg("col"), - py::arg("context") = py::none(), kContextGetFileLocationDocstring) + nb::arg("filename"), nb::arg("line"), nb::arg("col"), + nb::arg("context").none() = nb::none(), + kContextGetFileLocationDocstring) .def_static( "fused", [](const std::vector &pyLocations, @@ -2843,8 +2878,9 @@ void mlir::python::populateIRCore(py::module &m) { metadata ? metadata->get() : MlirAttribute{0}); return PyLocation(context->getRef(), location); }, - py::arg("locations"), py::arg("metadata") = py::none(), - py::arg("context") = py::none(), kContextGetFusedLocationDocstring) + nb::arg("locations"), nb::arg("metadata").none() = nb::none(), + nb::arg("context").none() = nb::none(), + kContextGetFusedLocationDocstring) .def_static( "name", [](std::string name, std::optional childLoc, @@ -2856,21 +2892,22 @@ void mlir::python::populateIRCore(py::module &m) { childLoc ? childLoc->get() : mlirLocationUnknownGet(context->get()))); }, - py::arg("name"), py::arg("childLoc") = py::none(), - py::arg("context") = py::none(), kContextGetNameLocationDocString) + nb::arg("name"), nb::arg("childLoc").none() = nb::none(), + nb::arg("context").none() = nb::none(), + kContextGetNameLocationDocString) .def_static( "from_attr", [](PyAttribute &attribute, DefaultingPyMlirContext context) { return PyLocation(context->getRef(), mlirLocationFromAttribute(attribute)); }, - py::arg("attribute"), py::arg("context") = py::none(), + nb::arg("attribute"), nb::arg("context").none() = nb::none(), "Gets a Location from a LocationAttr") - .def_property_readonly( + .def_prop_ro( "context", [](PyLocation &self) { return self.getContext().getObject(); }, "Context that owns the Location") - .def_property_readonly( + .def_prop_ro( "attr", [](PyLocation &self) { return mlirLocationGetAttribute(self); }, "Get the underlying LocationAttr") @@ -2879,7 +2916,7 @@ void mlir::python::populateIRCore(py::module &m) { [](PyLocation &self, std::string message) { mlirEmitError(self, message.c_str()); }, - py::arg("message"), "Emits an error at this location") + nb::arg("message"), "Emits an error at this location") .def("__repr__", [](PyLocation &self) { PyPrintAccumulator printAccum; mlirLocationPrint(self, printAccum.getCallback(), @@ -2890,8 +2927,8 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of Module //---------------------------------------------------------------------------- - py::class_(m, "Module", py::module_local()) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) + nb::class_(m, "Module", nb::is_weak_referenceable()) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) .def_static( "parse", @@ -2903,7 +2940,19 @@ void mlir::python::populateIRCore(py::module &m) { throw MLIRError("Unable to parse module assembly", errors.take()); return PyModule::forModule(module).releaseObject(); }, - py::arg("asm"), py::arg("context") = py::none(), + nb::arg("asm"), nb::arg("context").none() = nb::none(), + kModuleParseDocstring) + .def_static( + "parse", + [](nb::bytes moduleAsm, DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); + MlirModule module = mlirModuleCreateParse( + context->get(), toMlirStringRef(moduleAsm)); + if (mlirModuleIsNull(module)) + throw MLIRError("Unable to parse module assembly", errors.take()); + return PyModule::forModule(module).releaseObject(); + }, + nb::arg("asm"), nb::arg("context").none() = nb::none(), kModuleParseDocstring) .def_static( "create", @@ -2911,12 +2960,12 @@ void mlir::python::populateIRCore(py::module &m) { MlirModule module = mlirModuleCreateEmpty(loc); return PyModule::forModule(module).releaseObject(); }, - py::arg("loc") = py::none(), "Creates an empty module") - .def_property_readonly( + nb::arg("loc").none() = nb::none(), "Creates an empty module") + .def_prop_ro( "context", [](PyModule &self) { return self.getContext().getObject(); }, "Context that created the Module") - .def_property_readonly( + .def_prop_ro( "operation", [](PyModule &self) { return PyOperation::forOperation(self.getContext(), @@ -2925,7 +2974,7 @@ void mlir::python::populateIRCore(py::module &m) { .releaseObject(); }, "Accesses the module as an operation") - .def_property_readonly( + .def_prop_ro( "body", [](PyModule &self) { PyOperationRef moduleOp = PyOperation::forOperation( @@ -2943,7 +2992,7 @@ void mlir::python::populateIRCore(py::module &m) { kDumpDocstring) .def( "__str__", - [](py::object self) { + [](nb::object self) { // Defer to the operation's __str__. return self.attr("operation").attr("__str__")(); }, @@ -2952,27 +3001,26 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of Operation. //---------------------------------------------------------------------------- - py::class_(m, "_OperationBase", py::module_local()) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - [](PyOperationBase &self) { - return self.getOperation().getCapsule(); - }) + nb::class_(m, "_OperationBase") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, + [](PyOperationBase &self) { + return self.getOperation().getCapsule(); + }) .def("__eq__", [](PyOperationBase &self, PyOperationBase &other) { return &self.getOperation() == &other.getOperation(); }) .def("__eq__", - [](PyOperationBase &self, py::object other) { return false; }) + [](PyOperationBase &self, nb::object other) { return false; }) .def("__hash__", [](PyOperationBase &self) { return static_cast(llvm::hash_value(&self.getOperation())); }) - .def_property_readonly("attributes", - [](PyOperationBase &self) { - return PyOpAttributeMap( - self.getOperation().getRef()); - }) - .def_property_readonly( + .def_prop_ro("attributes", + [](PyOperationBase &self) { + return PyOpAttributeMap(self.getOperation().getRef()); + }) + .def_prop_ro( "context", [](PyOperationBase &self) { PyOperation &concreteOperation = self.getOperation(); @@ -2980,46 +3028,44 @@ void mlir::python::populateIRCore(py::module &m) { return concreteOperation.getContext().getObject(); }, "Context that owns the Operation") - .def_property_readonly("name", - [](PyOperationBase &self) { - auto &concreteOperation = self.getOperation(); - concreteOperation.checkValid(); - MlirOperation operation = - concreteOperation.get(); - MlirStringRef name = mlirIdentifierStr( - mlirOperationGetName(operation)); - return py::str(name.data, name.length); - }) - .def_property_readonly("operands", - [](PyOperationBase &self) { - return PyOpOperandList( - self.getOperation().getRef()); - }) - .def_property_readonly("regions", - [](PyOperationBase &self) { - return PyRegionList( - self.getOperation().getRef()); - }) - .def_property_readonly( + .def_prop_ro("name", + [](PyOperationBase &self) { + auto &concreteOperation = self.getOperation(); + concreteOperation.checkValid(); + MlirOperation operation = concreteOperation.get(); + MlirStringRef name = + mlirIdentifierStr(mlirOperationGetName(operation)); + return nb::str(name.data, name.length); + }) + .def_prop_ro("operands", + [](PyOperationBase &self) { + return PyOpOperandList(self.getOperation().getRef()); + }) + .def_prop_ro("regions", + [](PyOperationBase &self) { + return PyRegionList(self.getOperation().getRef()); + }) + .def_prop_ro( "results", [](PyOperationBase &self) { return PyOpResultList(self.getOperation().getRef()); }, "Returns the list of Operation results.") - .def_property_readonly( + .def_prop_ro( "result", [](PyOperationBase &self) { auto &operation = self.getOperation(); auto numResults = mlirOperationGetNumResults(operation); if (numResults != 1) { auto name = mlirIdentifierStr(mlirOperationGetName(operation)); - throw py::value_error( + throw nb::value_error( (Twine("Cannot call .result on operation ") + StringRef(name.data, name.length) + " which has " + Twine(numResults) + " results (it is only valid for operations with a " "single result)") - .str()); + .str() + .c_str()); } return PyOpResult(operation.getRef(), mlirOperationGetResult(operation, 0)) @@ -3027,7 +3073,7 @@ void mlir::python::populateIRCore(py::module &m) { }, "Shortcut to get an op result if it has only one (throws an error " "otherwise).") - .def_property_readonly( + .def_prop_ro( "location", [](PyOperationBase &self) { PyOperation &operation = self.getOperation(); @@ -3036,14 +3082,13 @@ void mlir::python::populateIRCore(py::module &m) { }, "Returns the source location the operation was defined or derived " "from.") - .def_property_readonly("parent", - [](PyOperationBase &self) -> py::object { - auto parent = - self.getOperation().getParentOperation(); - if (parent) - return parent->getObject(); - return py::none(); - }) + .def_prop_ro("parent", + [](PyOperationBase &self) -> nb::object { + auto parent = self.getOperation().getParentOperation(); + if (parent) + return parent->getObject(); + return nb::none(); + }) .def( "__str__", [](PyOperationBase &self) { @@ -3058,75 +3103,76 @@ void mlir::python::populateIRCore(py::module &m) { }, "Returns the assembly form of the operation.") .def("print", - py::overload_cast( + nb::overload_cast( &PyOperationBase::print), - py::arg("state"), py::arg("file") = py::none(), - py::arg("binary") = false, kOperationPrintStateDocstring) + nb::arg("state"), nb::arg("file").none() = nb::none(), + nb::arg("binary") = false, kOperationPrintStateDocstring) .def("print", - py::overload_cast, bool, bool, bool, bool, - bool, py::object, bool, bool>( + nb::overload_cast, bool, bool, bool, bool, + bool, nb::object, bool, bool>( &PyOperationBase::print), // Careful: Lots of arguments must match up with print method. - py::arg("large_elements_limit") = py::none(), - py::arg("enable_debug_info") = false, - py::arg("pretty_debug_info") = false, - py::arg("print_generic_op_form") = false, - py::arg("use_local_scope") = false, - py::arg("assume_verified") = false, py::arg("file") = py::none(), - py::arg("binary") = false, py::arg("skip_regions") = false, - kOperationPrintDocstring) - .def("write_bytecode", &PyOperationBase::writeBytecode, py::arg("file"), - py::arg("desired_version") = py::none(), + nb::arg("large_elements_limit").none() = nb::none(), + nb::arg("enable_debug_info") = false, + nb::arg("pretty_debug_info") = false, + nb::arg("print_generic_op_form") = false, + nb::arg("use_local_scope") = false, + nb::arg("assume_verified") = false, + nb::arg("file").none() = nb::none(), nb::arg("binary") = false, + nb::arg("skip_regions") = false, kOperationPrintDocstring) + .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"), + nb::arg("desired_version").none() = nb::none(), kOperationPrintBytecodeDocstring) .def("get_asm", &PyOperationBase::getAsm, // Careful: Lots of arguments must match up with get_asm method. - py::arg("binary") = false, - py::arg("large_elements_limit") = py::none(), - py::arg("enable_debug_info") = false, - py::arg("pretty_debug_info") = false, - py::arg("print_generic_op_form") = false, - py::arg("use_local_scope") = false, - py::arg("assume_verified") = false, py::arg("skip_regions") = false, + nb::arg("binary") = false, + nb::arg("large_elements_limit").none() = nb::none(), + nb::arg("enable_debug_info") = false, + nb::arg("pretty_debug_info") = false, + nb::arg("print_generic_op_form") = false, + nb::arg("use_local_scope") = false, + nb::arg("assume_verified") = false, nb::arg("skip_regions") = false, kOperationGetAsmDocstring) .def("verify", &PyOperationBase::verify, "Verify the operation. Raises MLIRError if verification fails, and " "returns true otherwise.") - .def("move_after", &PyOperationBase::moveAfter, py::arg("other"), + .def("move_after", &PyOperationBase::moveAfter, nb::arg("other"), "Puts self immediately after the other operation in its parent " "block.") - .def("move_before", &PyOperationBase::moveBefore, py::arg("other"), + .def("move_before", &PyOperationBase::moveBefore, nb::arg("other"), "Puts self immediately before the other operation in its parent " "block.") .def( "clone", - [](PyOperationBase &self, py::object ip) { + [](PyOperationBase &self, nb::object ip) { return self.getOperation().clone(ip); }, - py::arg("ip") = py::none()) + nb::arg("ip").none() = nb::none()) .def( "detach_from_parent", [](PyOperationBase &self) { PyOperation &operation = self.getOperation(); operation.checkValid(); if (!operation.isAttached()) - throw py::value_error("Detached operation has no parent."); + throw nb::value_error("Detached operation has no parent."); operation.detachFromParent(); return operation.createOpView(); }, "Detaches the operation from its parent block.") .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); }) - .def("walk", &PyOperationBase::walk, py::arg("callback"), - py::arg("walk_order") = MlirWalkPostOrder); - - py::class_(m, "Operation", py::module_local()) - .def_static("create", &PyOperation::create, py::arg("name"), - py::arg("results") = py::none(), - py::arg("operands") = py::none(), - py::arg("attributes") = py::none(), - py::arg("successors") = py::none(), py::arg("regions") = 0, - py::arg("loc") = py::none(), py::arg("ip") = py::none(), - py::arg("infer_type") = false, kOperationCreateDocstring) + .def("walk", &PyOperationBase::walk, nb::arg("callback"), + nb::arg("walk_order") = MlirWalkPostOrder); + + nb::class_(m, "Operation") + .def_static("create", &PyOperation::create, nb::arg("name"), + nb::arg("results").none() = nb::none(), + nb::arg("operands").none() = nb::none(), + nb::arg("attributes").none() = nb::none(), + nb::arg("successors").none() = nb::none(), + nb::arg("regions") = 0, nb::arg("loc").none() = nb::none(), + nb::arg("ip").none() = nb::none(), + nb::arg("infer_type") = false, kOperationCreateDocstring) .def_static( "parse", [](const std::string &sourceStr, const std::string &sourceName, @@ -3134,16 +3180,15 @@ void mlir::python::populateIRCore(py::module &m) { return PyOperation::parse(context->getRef(), sourceStr, sourceName) ->createOpView(); }, - py::arg("source"), py::kw_only(), py::arg("source_name") = "", - py::arg("context") = py::none(), + nb::arg("source"), nb::kw_only(), nb::arg("source_name") = "", + nb::arg("context").none() = nb::none(), "Parses an operation. Supports both text assembly format and binary " "bytecode format.") - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyOperation::getCapsule) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) - .def_property_readonly("operation", [](py::object self) { return self; }) - .def_property_readonly("opview", &PyOperation::createOpView) - .def_property_readonly( + .def_prop_ro("operation", [](nb::object self) { return self; }) + .def_prop_ro("opview", &PyOperation::createOpView) + .def_prop_ro( "successors", [](PyOperationBase &self) { return PyOpSuccessors(self.getOperation().getRef()); @@ -3151,30 +3196,33 @@ void mlir::python::populateIRCore(py::module &m) { "Returns the list of Operation successors."); auto opViewClass = - py::class_(m, "OpView", py::module_local()) - .def(py::init(), py::arg("operation")) - .def_property_readonly("operation", &PyOpView::getOperationObject) - .def_property_readonly("opview", [](py::object self) { return self; }) + nb::class_(m, "OpView") + .def(nb::init(), nb::arg("operation")) + .def_prop_ro("operation", &PyOpView::getOperationObject) + .def_prop_ro("opview", [](nb::object self) { return self; }) .def( "__str__", - [](PyOpView &self) { return py::str(self.getOperationObject()); }) - .def_property_readonly( + [](PyOpView &self) { return nb::str(self.getOperationObject()); }) + .def_prop_ro( "successors", [](PyOperationBase &self) { return PyOpSuccessors(self.getOperation().getRef()); }, "Returns the list of Operation successors."); - opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); - opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); - opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); + opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true); + opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none(); + opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none(); opViewClass.attr("build_generic") = classmethod( - &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), - py::arg("operands") = py::none(), py::arg("attributes") = py::none(), - py::arg("successors") = py::none(), py::arg("regions") = py::none(), - py::arg("loc") = py::none(), py::arg("ip") = py::none(), + &PyOpView::buildGeneric, nb::arg("cls"), + nb::arg("results").none() = nb::none(), + nb::arg("operands").none() = nb::none(), + nb::arg("attributes").none() = nb::none(), + nb::arg("successors").none() = nb::none(), + nb::arg("regions").none() = nb::none(), + nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(), "Builds a specific, generated OpView based on class level attributes."); opViewClass.attr("parse") = classmethod( - [](const py::object &cls, const std::string &sourceStr, + [](const nb::object &cls, const std::string &sourceStr, const std::string &sourceName, DefaultingPyMlirContext context) { PyOperationRef parsed = PyOperation::parse(context->getRef(), sourceStr, sourceName); @@ -3185,30 +3233,30 @@ void mlir::python::populateIRCore(py::module &m) { // `OpView` subclasses, and is not intended to be used on `OpView` // directly. std::string clsOpName = - py::cast(cls.attr("OPERATION_NAME")); + nb::cast(cls.attr("OPERATION_NAME")); MlirStringRef identifier = mlirIdentifierStr(mlirOperationGetName(*parsed.get())); std::string_view parsedOpName(identifier.data, identifier.length); if (clsOpName != parsedOpName) throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" + parsedOpName + "'"); - return PyOpView::constructDerived(cls, *parsed.get()); + return PyOpView::constructDerived(cls, parsed.getObject()); }, - py::arg("cls"), py::arg("source"), py::kw_only(), - py::arg("source_name") = "", py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("source"), nb::kw_only(), + nb::arg("source_name") = "", nb::arg("context").none() = nb::none(), "Parses a specific, generated OpView based on class level attributes"); //---------------------------------------------------------------------------- // Mapping of PyRegion. //---------------------------------------------------------------------------- - py::class_(m, "Region", py::module_local()) - .def_property_readonly( + nb::class_(m, "Region") + .def_prop_ro( "blocks", [](PyRegion &self) { return PyBlockList(self.getParentOperation(), self.get()); }, "Returns a forward-optimized sequence of blocks.") - .def_property_readonly( + .def_prop_ro( "owner", [](PyRegion &self) { return self.getParentOperation()->createOpView(); @@ -3226,27 +3274,27 @@ void mlir::python::populateIRCore(py::module &m) { [](PyRegion &self, PyRegion &other) { return self.get().ptr == other.get().ptr; }) - .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); + .def("__eq__", [](PyRegion &self, nb::object &other) { return false; }); //---------------------------------------------------------------------------- // Mapping of PyBlock. //---------------------------------------------------------------------------- - py::class_(m, "Block", py::module_local()) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule) - .def_property_readonly( + nb::class_(m, "Block") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule) + .def_prop_ro( "owner", [](PyBlock &self) { return self.getParentOperation()->createOpView(); }, "Returns the owning operation of this block.") - .def_property_readonly( + .def_prop_ro( "region", [](PyBlock &self) { MlirRegion region = mlirBlockGetParentRegion(self.get()); return PyRegion(self.getParentOperation(), region); }, "Returns the owning region of this block.") - .def_property_readonly( + .def_prop_ro( "arguments", [](PyBlock &self) { return PyBlockArgumentList(self.getParentOperation(), self.get()); @@ -3265,7 +3313,7 @@ void mlir::python::populateIRCore(py::module &m) { return mlirBlockEraseArgument(self.get(), index); }, "Erase the argument at 'index' and remove it from the argument list.") - .def_property_readonly( + .def_prop_ro( "operations", [](PyBlock &self) { return PyOperationList(self.getParentOperation(), self.get()); @@ -3273,15 +3321,15 @@ void mlir::python::populateIRCore(py::module &m) { "Returns a forward-optimized sequence of operations.") .def_static( "create_at_start", - [](PyRegion &parent, const py::list &pyArgTypes, - const std::optional &pyArgLocs) { + [](PyRegion &parent, const nb::sequence &pyArgTypes, + const std::optional &pyArgLocs) { parent.checkValid(); MlirBlock block = createBlock(pyArgTypes, pyArgLocs); mlirRegionInsertOwnedBlock(parent, 0, block); return PyBlock(parent.getParentOperation(), block); }, - py::arg("parent"), py::arg("arg_types") = py::list(), - py::arg("arg_locs") = std::nullopt, + nb::arg("parent"), nb::arg("arg_types") = nb::list(), + nb::arg("arg_locs") = std::nullopt, "Creates and returns a new Block at the beginning of the given " "region (with given argument types and locations).") .def( @@ -3295,28 +3343,32 @@ void mlir::python::populateIRCore(py::module &m) { "Append this block to a region, transferring ownership if necessary") .def( "create_before", - [](PyBlock &self, const py::args &pyArgTypes, - const std::optional &pyArgLocs) { + [](PyBlock &self, const nb::args &pyArgTypes, + const std::optional &pyArgLocs) { self.checkValid(); - MlirBlock block = createBlock(pyArgTypes, pyArgLocs); + MlirBlock block = + createBlock(nb::cast(pyArgTypes), pyArgLocs); MlirRegion region = mlirBlockGetParentRegion(self.get()); mlirRegionInsertOwnedBlockBefore(region, self.get(), block); return PyBlock(self.getParentOperation(), block); }, - py::arg("arg_locs") = std::nullopt, + nb::arg("arg_types"), nb::kw_only(), + nb::arg("arg_locs") = std::nullopt, "Creates and returns a new Block before this block " "(with given argument types and locations).") .def( "create_after", - [](PyBlock &self, const py::args &pyArgTypes, - const std::optional &pyArgLocs) { + [](PyBlock &self, const nb::args &pyArgTypes, + const std::optional &pyArgLocs) { self.checkValid(); - MlirBlock block = createBlock(pyArgTypes, pyArgLocs); + MlirBlock block = + createBlock(nb::cast(pyArgTypes), pyArgLocs); MlirRegion region = mlirBlockGetParentRegion(self.get()); mlirRegionInsertOwnedBlockAfter(region, self.get(), block); return PyBlock(self.getParentOperation(), block); }, - py::arg("arg_locs") = std::nullopt, + nb::arg("arg_types"), nb::kw_only(), + nb::arg("arg_locs") = std::nullopt, "Creates and returns a new Block after this block " "(with given argument types and locations).") .def( @@ -3333,7 +3385,7 @@ void mlir::python::populateIRCore(py::module &m) { [](PyBlock &self, PyBlock &other) { return self.get().ptr == other.get().ptr; }) - .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) + .def("__eq__", [](PyBlock &self, nb::object &other) { return false; }) .def("__hash__", [](PyBlock &self) { return static_cast(llvm::hash_value(self.get().ptr)); @@ -3359,7 +3411,7 @@ void mlir::python::populateIRCore(py::module &m) { operation.getOperation().setAttached( self.getParentOperation().getObject()); }, - py::arg("operation"), + nb::arg("operation"), "Appends an operation to this block. If the operation is currently " "in another block, it will be moved."); @@ -3367,39 +3419,41 @@ void mlir::python::populateIRCore(py::module &m) { // Mapping of PyInsertionPoint. //---------------------------------------------------------------------------- - py::class_(m, "InsertionPoint", py::module_local()) - .def(py::init(), py::arg("block"), + nb::class_(m, "InsertionPoint") + .def(nb::init(), nb::arg("block"), "Inserts after the last operation but still inside the block.") .def("__enter__", &PyInsertionPoint::contextEnter) - .def("__exit__", &PyInsertionPoint::contextExit) - .def_property_readonly_static( + .def("__exit__", &PyInsertionPoint::contextExit, + nb::arg("exc_type").none(), nb::arg("exc_value").none(), + nb::arg("traceback").none()) + .def_prop_ro_static( "current", - [](py::object & /*class*/) { + [](nb::object & /*class*/) { auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); if (!ip) - throw py::value_error("No current InsertionPoint"); + throw nb::value_error("No current InsertionPoint"); return ip; }, "Gets the InsertionPoint bound to the current thread or raises " "ValueError if none has been set") - .def(py::init(), py::arg("beforeOperation"), + .def(nb::init(), nb::arg("beforeOperation"), "Inserts before a referenced operation.") .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, - py::arg("block"), "Inserts at the beginning of the block.") + nb::arg("block"), "Inserts at the beginning of the block.") .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, - py::arg("block"), "Inserts before the block terminator.") - .def("insert", &PyInsertionPoint::insert, py::arg("operation"), + nb::arg("block"), "Inserts before the block terminator.") + .def("insert", &PyInsertionPoint::insert, nb::arg("operation"), "Inserts an operation.") - .def_property_readonly( + .def_prop_ro( "block", [](PyInsertionPoint &self) { return self.getBlock(); }, "Returns the block that this InsertionPoint points to.") - .def_property_readonly( + .def_prop_ro( "ref_operation", - [](PyInsertionPoint &self) -> py::object { + [](PyInsertionPoint &self) -> nb::object { auto refOperation = self.getRefOperation(); if (refOperation) return refOperation->getObject(); - return py::none(); + return nb::none(); }, "The reference operation before which new operations are " "inserted, or None if the insertion point is at the end of " @@ -3408,13 +3462,12 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of PyAttribute. //---------------------------------------------------------------------------- - py::class_(m, "Attribute", py::module_local()) + nb::class_(m, "Attribute") // Delegate to the PyAttribute copy constructor, which will also lifetime // extend the backing context which owns the MlirAttribute. - .def(py::init(), py::arg("cast_from_type"), + .def(nb::init(), nb::arg("cast_from_type"), "Casts the passed attribute to the generic Attribute") - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyAttribute::getCapsule) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) .def_static( "parse", @@ -3426,24 +3479,24 @@ void mlir::python::populateIRCore(py::module &m) { throw MLIRError("Unable to parse attribute", errors.take()); return attr; }, - py::arg("asm"), py::arg("context") = py::none(), + nb::arg("asm"), nb::arg("context").none() = nb::none(), "Parses an attribute from an assembly form. Raises an MLIRError on " "failure.") - .def_property_readonly( + .def_prop_ro( "context", [](PyAttribute &self) { return self.getContext().getObject(); }, "Context that owns the Attribute") - .def_property_readonly( - "type", [](PyAttribute &self) { return mlirAttributeGetType(self); }) + .def_prop_ro("type", + [](PyAttribute &self) { return mlirAttributeGetType(self); }) .def( "get_named", [](PyAttribute &self, std::string name) { return PyNamedAttribute(self, std::move(name)); }, - py::keep_alive<0, 1>(), "Binds a name to the attribute") + nb::keep_alive<0, 1>(), "Binds a name to the attribute") .def("__eq__", [](PyAttribute &self, PyAttribute &other) { return self == other; }) - .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) + .def("__eq__", [](PyAttribute &self, nb::object &other) { return false; }) .def("__hash__", [](PyAttribute &self) { return static_cast(llvm::hash_value(self.get().ptr)); @@ -3474,36 +3527,35 @@ void mlir::python::populateIRCore(py::module &m) { printAccum.parts.append(")"); return printAccum.join(); }) - .def_property_readonly( - "typeid", - [](PyAttribute &self) -> MlirTypeID { - MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); - assert(!mlirTypeIDIsNull(mlirTypeID) && - "mlirTypeID was expected to be non-null."); - return mlirTypeID; - }) + .def_prop_ro("typeid", + [](PyAttribute &self) -> MlirTypeID { + MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); + assert(!mlirTypeIDIsNull(mlirTypeID) && + "mlirTypeID was expected to be non-null."); + return mlirTypeID; + }) .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyAttribute &self) { MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); assert(!mlirTypeIDIsNull(mlirTypeID) && "mlirTypeID was expected to be non-null."); - std::optional typeCaster = + std::optional typeCaster = PyGlobals::get().lookupTypeCaster(mlirTypeID, mlirAttributeGetDialect(self)); if (!typeCaster) - return py::cast(self); + return nb::cast(self); return typeCaster.value()(self); }); //---------------------------------------------------------------------------- // Mapping of PyNamedAttribute //---------------------------------------------------------------------------- - py::class_(m, "NamedAttribute", py::module_local()) + nb::class_(m, "NamedAttribute") .def("__repr__", [](PyNamedAttribute &self) { PyPrintAccumulator printAccum; printAccum.parts.append("NamedAttribute("); printAccum.parts.append( - py::str(mlirIdentifierStr(self.namedAttr.name).data, + nb::str(mlirIdentifierStr(self.namedAttr.name).data, mlirIdentifierStr(self.namedAttr.name).length)); printAccum.parts.append("="); mlirAttributePrint(self.namedAttr.attribute, @@ -3512,28 +3564,28 @@ void mlir::python::populateIRCore(py::module &m) { printAccum.parts.append(")"); return printAccum.join(); }) - .def_property_readonly( + .def_prop_ro( "name", [](PyNamedAttribute &self) { - return py::str(mlirIdentifierStr(self.namedAttr.name).data, + return nb::str(mlirIdentifierStr(self.namedAttr.name).data, mlirIdentifierStr(self.namedAttr.name).length); }, "The name of the NamedAttribute binding") - .def_property_readonly( + .def_prop_ro( "attr", [](PyNamedAttribute &self) { return self.namedAttr.attribute; }, - py::keep_alive<0, 1>(), + nb::keep_alive<0, 1>(), "The underlying generic attribute of the NamedAttribute binding"); //---------------------------------------------------------------------------- // Mapping of PyType. //---------------------------------------------------------------------------- - py::class_(m, "Type", py::module_local()) + nb::class_(m, "Type") // Delegate to the PyType copy constructor, which will also lifetime // extend the backing context which owns the MlirType. - .def(py::init(), py::arg("cast_from_type"), + .def(nb::init(), nb::arg("cast_from_type"), "Casts the passed type to the generic Type") - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) .def_static( "parse", @@ -3545,13 +3597,15 @@ void mlir::python::populateIRCore(py::module &m) { throw MLIRError("Unable to parse type", errors.take()); return type; }, - py::arg("asm"), py::arg("context") = py::none(), + nb::arg("asm"), nb::arg("context").none() = nb::none(), kContextParseTypeDocstring) - .def_property_readonly( + .def_prop_ro( "context", [](PyType &self) { return self.getContext().getObject(); }, "Context that owns the Type") .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) - .def("__eq__", [](PyType &self, py::object &other) { return false; }) + .def( + "__eq__", [](PyType &self, nb::object &other) { return false; }, + nb::arg("other").none()) .def("__hash__", [](PyType &self) { return static_cast(llvm::hash_value(self.get().ptr)); @@ -3585,28 +3639,27 @@ void mlir::python::populateIRCore(py::module &m) { MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); assert(!mlirTypeIDIsNull(mlirTypeID) && "mlirTypeID was expected to be non-null."); - std::optional typeCaster = + std::optional typeCaster = PyGlobals::get().lookupTypeCaster(mlirTypeID, mlirTypeGetDialect(self)); if (!typeCaster) - return py::cast(self); + return nb::cast(self); return typeCaster.value()(self); }) - .def_property_readonly("typeid", [](PyType &self) -> MlirTypeID { + .def_prop_ro("typeid", [](PyType &self) -> MlirTypeID { MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); if (!mlirTypeIDIsNull(mlirTypeID)) return mlirTypeID; - auto origRepr = - pybind11::repr(pybind11::cast(self)).cast(); - throw py::value_error( - (origRepr + llvm::Twine(" has no typeid.")).str()); + auto origRepr = nb::cast(nb::repr(nb::cast(self))); + throw nb::value_error( + (origRepr + llvm::Twine(" has no typeid.")).str().c_str()); }); //---------------------------------------------------------------------------- // Mapping of PyTypeID. //---------------------------------------------------------------------------- - py::class_(m, "TypeID", py::module_local()) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule) + nb::class_(m, "TypeID") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule) // Note, this tests whether the underlying TypeIDs are the same, // not whether the wrapper MlirTypeIDs are the same, nor whether @@ -3614,7 +3667,7 @@ void mlir::python::populateIRCore(py::module &m) { .def("__eq__", [](PyTypeID &self, PyTypeID &other) { return self == other; }) .def("__eq__", - [](PyTypeID &self, const py::object &other) { return false; }) + [](PyTypeID &self, const nb::object &other) { return false; }) // Note, this gives the hash value of the underlying TypeID, not the // hash value of the Python object, nor the hash value of the // MlirTypeID wrapper. @@ -3625,20 +3678,20 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of Value. //---------------------------------------------------------------------------- - py::class_(m, "Value", py::module_local()) - .def(py::init(), py::keep_alive<0, 1>(), py::arg("value")) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) + nb::class_(m, "Value") + .def(nb::init(), nb::keep_alive<0, 1>(), nb::arg("value")) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) - .def_property_readonly( + .def_prop_ro( "context", [](PyValue &self) { return self.getParentOperation()->getContext(); }, "Context in which the value lives.") .def( "dump", [](PyValue &self) { mlirValueDump(self.get()); }, kDumpDocstring) - .def_property_readonly( + .def_prop_ro( "owner", - [](PyValue &self) -> py::object { + [](PyValue &self) -> nb::object { MlirValue v = self.get(); if (mlirValueIsAOpResult(v)) { assert( @@ -3651,22 +3704,22 @@ void mlir::python::populateIRCore(py::module &m) { if (mlirValueIsABlockArgument(v)) { MlirBlock block = mlirBlockArgumentGetOwner(self.get()); - return py::cast(PyBlock(self.getParentOperation(), block)); + return nb::cast(PyBlock(self.getParentOperation(), block)); } assert(false && "Value must be a block argument or an op result"); - return py::none(); + return nb::none(); }) - .def_property_readonly("uses", - [](PyValue &self) { - return PyOpOperandIterator( - mlirValueGetFirstUse(self.get())); - }) + .def_prop_ro("uses", + [](PyValue &self) { + return PyOpOperandIterator( + mlirValueGetFirstUse(self.get())); + }) .def("__eq__", [](PyValue &self, PyValue &other) { return self.get().ptr == other.get().ptr; }) - .def("__eq__", [](PyValue &self, py::object other) { return false; }) + .def("__eq__", [](PyValue &self, nb::object other) { return false; }) .def("__hash__", [](PyValue &self) { return static_cast(llvm::hash_value(self.get().ptr)); @@ -3698,26 +3751,26 @@ void mlir::python::populateIRCore(py::module &m) { mlirAsmStateDestroy(valueState); return printAccum.join(); }, - py::arg("use_local_scope") = false) + nb::arg("use_local_scope") = false) .def( "get_name", - [](PyValue &self, std::reference_wrapper state) { + [](PyValue &self, PyAsmState &state) { PyPrintAccumulator printAccum; - MlirAsmState valueState = state.get().get(); + MlirAsmState valueState = state.get(); mlirValuePrintAsOperand(self.get(), valueState, printAccum.getCallback(), printAccum.getUserData()); return printAccum.join(); }, - py::arg("state"), kGetNameAsOperand) - .def_property_readonly( - "type", [](PyValue &self) { return mlirValueGetType(self.get()); }) + nb::arg("state"), kGetNameAsOperand) + .def_prop_ro("type", + [](PyValue &self) { return mlirValueGetType(self.get()); }) .def( "set_type", [](PyValue &self, const PyType &type) { return mlirValueSetType(self.get(), type); }, - py::arg("type")) + nb::arg("type")) .def( "replace_all_uses_with", [](PyValue &self, PyValue &with) { @@ -3730,22 +3783,22 @@ void mlir::python::populateIRCore(py::module &m) { MlirOperation exceptedUser = exception.get(); mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser); }, - py::arg("with"), py::arg("exceptions"), + nb::arg("with"), nb::arg("exceptions"), kValueReplaceAllUsesExceptDocstring) .def( "replace_all_uses_except", - [](MlirValue self, MlirValue with, py::list exceptions) { + [](MlirValue self, MlirValue with, nb::list exceptions) { // Convert Python list to a SmallVector of MlirOperations llvm::SmallVector exceptionOps; - for (py::handle exception : exceptions) { - exceptionOps.push_back(exception.cast().get()); + for (nb::handle exception : exceptions) { + exceptionOps.push_back(nb::cast(exception).get()); } mlirValueReplaceAllUsesExcept( self, with, static_cast(exceptionOps.size()), exceptionOps.data()); }, - py::arg("with"), py::arg("exceptions"), + nb::arg("with"), nb::arg("exceptions"), kValueReplaceAllUsesExceptDocstring) .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyValue &self) { return self.maybeDownCast(); }); @@ -3753,20 +3806,20 @@ void mlir::python::populateIRCore(py::module &m) { PyOpResult::bind(m); PyOpOperand::bind(m); - py::class_(m, "AsmState", py::module_local()) - .def(py::init(), py::arg("value"), - py::arg("use_local_scope") = false) - .def(py::init(), py::arg("op"), - py::arg("use_local_scope") = false); + nb::class_(m, "AsmState") + .def(nb::init(), nb::arg("value"), + nb::arg("use_local_scope") = false) + .def(nb::init(), nb::arg("op"), + nb::arg("use_local_scope") = false); //---------------------------------------------------------------------------- // Mapping of SymbolTable. //---------------------------------------------------------------------------- - py::class_(m, "SymbolTable", py::module_local()) - .def(py::init()) + nb::class_(m, "SymbolTable") + .def(nb::init()) .def("__getitem__", &PySymbolTable::dunderGetItem) - .def("insert", &PySymbolTable::insert, py::arg("operation")) - .def("erase", &PySymbolTable::erase, py::arg("operation")) + .def("insert", &PySymbolTable::insert, nb::arg("operation")) + .def("erase", &PySymbolTable::erase, nb::arg("operation")) .def("__delitem__", &PySymbolTable::dunderDel) .def("__contains__", [](PySymbolTable &table, const std::string &name) { @@ -3775,19 +3828,19 @@ void mlir::python::populateIRCore(py::module &m) { }) // Static helpers. .def_static("set_symbol_name", &PySymbolTable::setSymbolName, - py::arg("symbol"), py::arg("name")) + nb::arg("symbol"), nb::arg("name")) .def_static("get_symbol_name", &PySymbolTable::getSymbolName, - py::arg("symbol")) + nb::arg("symbol")) .def_static("get_visibility", &PySymbolTable::getVisibility, - py::arg("symbol")) + nb::arg("symbol")) .def_static("set_visibility", &PySymbolTable::setVisibility, - py::arg("symbol"), py::arg("visibility")) + nb::arg("symbol"), nb::arg("visibility")) .def_static("replace_all_symbol_uses", - &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"), - py::arg("new_symbol"), py::arg("from_op")) + &PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"), + nb::arg("new_symbol"), nb::arg("from_op")) .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables, - py::arg("from_op"), py::arg("all_sym_uses_visible"), - py::arg("callback")); + nb::arg("from_op"), nb::arg("all_sym_uses_visible"), + nb::arg("callback")); // Container bindings. PyBlockArgumentList::bind(m); @@ -3809,14 +3862,15 @@ void mlir::python::populateIRCore(py::module &m) { // Attribute builder getter. PyAttrBuilderMap::bind(m); - py::register_local_exception_translator([](std::exception_ptr p) { + nb::register_exception_translator([](const std::exception_ptr &p, + void *payload) { // We can't define exceptions with custom fields through pybind, so instead // the exception class is defined in python and imported here. try { if (p) std::rethrow_exception(p); } catch (const MLIRError &e) { - py::object obj = py::module_::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) .attr("MLIRError")(e.message, e.errorDiagnostics); PyErr_SetObject(PyExc_Exception, obj.ptr()); } diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp index 54cfa56066eb8b..c339a93e31857b 100644 --- a/mlir/lib/Bindings/Python/IRInterfaces.cpp +++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp @@ -6,12 +6,12 @@ // //===----------------------------------------------------------------------===// +#include +#include +#include + #include #include -#include -#include -#include -#include #include #include #include @@ -24,7 +24,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -namespace py = pybind11; +namespace nb = nanobind; namespace mlir { namespace python { @@ -53,10 +53,10 @@ namespace { /// Takes in an optional ist of operands and converts them into a SmallVector /// of MlirVlaues. Returns an empty SmallVector if the list is empty. -llvm::SmallVector wrapOperands(std::optional operandList) { +llvm::SmallVector wrapOperands(std::optional operandList) { llvm::SmallVector mlirOperands; - if (!operandList || operandList->empty()) { + if (!operandList || operandList->size() == 0) { return mlirOperands; } @@ -68,40 +68,42 @@ llvm::SmallVector wrapOperands(std::optional operandList) { PyValue *val; try { - val = py::cast(it.value()); + val = nb::cast(it.value()); if (!val) - throw py::cast_error(); + throw nb::cast_error(); mlirOperands.push_back(val->get()); continue; - } catch (py::cast_error &err) { + } catch (nb::cast_error &err) { // Intentionally unhandled to try sequence below first. (void)err; } try { - auto vals = py::cast(it.value()); - for (py::object v : vals) { + auto vals = nb::cast(it.value()); + for (nb::handle v : vals) { try { - val = py::cast(v); + val = nb::cast(v); if (!val) - throw py::cast_error(); + throw nb::cast_error(); mlirOperands.push_back(val->get()); - } catch (py::cast_error &err) { - throw py::value_error( + } catch (nb::cast_error &err) { + throw nb::value_error( (llvm::Twine("Operand ") + llvm::Twine(it.index()) + " must be a Value or Sequence of Values (" + err.what() + ")") - .str()); + .str() + .c_str()); } } continue; - } catch (py::cast_error &err) { - throw py::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + + } catch (nb::cast_error &err) { + throw nb::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + " must be a Value or Sequence of Values (" + err.what() + ")") - .str()); + .str() + .c_str()); } - throw py::cast_error(); + throw nb::cast_error(); } return mlirOperands; @@ -144,24 +146,24 @@ wrapRegions(std::optional> regions) { template class PyConcreteOpInterface { protected: - using ClassTy = py::class_; + using ClassTy = nb::class_; using GetTypeIDFunctionTy = MlirTypeID (*)(); public: /// Constructs an interface instance from an object that is either an /// operation or a subclass of OpView. In the latter case, only the static /// methods of the interface are accessible to the caller. - PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context) + PyConcreteOpInterface(nb::object object, DefaultingPyMlirContext context) : obj(std::move(object)) { try { - operation = &py::cast(obj); - } catch (py::cast_error &) { + operation = &nb::cast(obj); + } catch (nb::cast_error &) { // Do nothing. } try { - operation = &py::cast(obj).getOperation(); - } catch (py::cast_error &) { + operation = &nb::cast(obj).getOperation(); + } catch (nb::cast_error &) { // Do nothing. } @@ -169,7 +171,7 @@ class PyConcreteOpInterface { if (!mlirOperationImplementsInterface(*operation, ConcreteIface::getInterfaceID())) { std::string msg = "the operation does not implement "; - throw py::value_error(msg + ConcreteIface::pyClassName); + throw nb::value_error((msg + ConcreteIface::pyClassName).c_str()); } MlirIdentifier identifier = mlirOperationGetName(*operation); @@ -177,9 +179,9 @@ class PyConcreteOpInterface { opName = std::string(stringRef.data, stringRef.length); } else { try { - opName = obj.attr("OPERATION_NAME").template cast(); - } catch (py::cast_error &) { - throw py::type_error( + opName = nb::cast(obj.attr("OPERATION_NAME")); + } catch (nb::cast_error &) { + throw nb::type_error( "Op interface does not refer to an operation or OpView class"); } @@ -187,22 +189,19 @@ class PyConcreteOpInterface { mlirStringRefCreate(opName.data(), opName.length()), context.resolve().get(), ConcreteIface::getInterfaceID())) { std::string msg = "the operation does not implement "; - throw py::value_error(msg + ConcreteIface::pyClassName); + throw nb::value_error((msg + ConcreteIface::pyClassName).c_str()); } } } /// Creates the Python bindings for this class in the given module. - static void bind(py::module &m) { - py::class_ cls(m, ConcreteIface::pyClassName, - py::module_local()); - cls.def(py::init(), py::arg("object"), - py::arg("context") = py::none(), constructorDoc) - .def_property_readonly("operation", - &PyConcreteOpInterface::getOperationObject, - operationDoc) - .def_property_readonly("opview", &PyConcreteOpInterface::getOpView, - opviewDoc); + static void bind(nb::module_ &m) { + nb::class_ cls(m, ConcreteIface::pyClassName); + cls.def(nb::init(), nb::arg("object"), + nb::arg("context").none() = nb::none(), constructorDoc) + .def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject, + operationDoc) + .def_prop_ro("opview", &PyConcreteOpInterface::getOpView, opviewDoc); ConcreteIface::bindDerived(cls); } @@ -216,9 +215,9 @@ class PyConcreteOpInterface { /// Returns the operation instance from which this object was constructed. /// Throws a type error if this object was constructed from a subclass of /// OpView. - py::object getOperationObject() { + nb::object getOperationObject() { if (operation == nullptr) { - throw py::type_error("Cannot get an operation from a static interface"); + throw nb::type_error("Cannot get an operation from a static interface"); } return operation->getRef().releaseObject(); @@ -227,9 +226,9 @@ class PyConcreteOpInterface { /// Returns the opview of the operation instance from which this object was /// constructed. Throws a type error if this object was constructed form a /// subclass of OpView. - py::object getOpView() { + nb::object getOpView() { if (operation == nullptr) { - throw py::type_error("Cannot get an opview from a static interface"); + throw nb::type_error("Cannot get an opview from a static interface"); } return operation->createOpView(); @@ -242,7 +241,7 @@ class PyConcreteOpInterface { private: PyOperation *operation = nullptr; std::string opName; - py::object obj; + nb::object obj; }; /// Python wrapper for InferTypeOpInterface. This interface has only static @@ -276,7 +275,7 @@ class PyInferTypeOpInterface /// Given the arguments required to build an operation, attempts to infer its /// return types. Throws value_error on failure. std::vector - inferReturnTypes(std::optional operandList, + inferReturnTypes(std::optional operandList, std::optional attributes, void *properties, std::optional> regions, DefaultingPyMlirContext context, @@ -299,7 +298,7 @@ class PyInferTypeOpInterface mlirRegions.data(), &appendResultsCallback, &data); if (mlirLogicalResultIsFailure(result)) { - throw py::value_error("Failed to infer result types"); + throw nb::value_error("Failed to infer result types"); } return inferredTypes; @@ -307,11 +306,12 @@ class PyInferTypeOpInterface static void bindDerived(ClassTy &cls) { cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes, - py::arg("operands") = py::none(), - py::arg("attributes") = py::none(), - py::arg("properties") = py::none(), py::arg("regions") = py::none(), - py::arg("context") = py::none(), py::arg("loc") = py::none(), - inferReturnTypesDoc); + nb::arg("operands").none() = nb::none(), + nb::arg("attributes").none() = nb::none(), + nb::arg("properties").none() = nb::none(), + nb::arg("regions").none() = nb::none(), + nb::arg("context").none() = nb::none(), + nb::arg("loc").none() = nb::none(), inferReturnTypesDoc); } }; @@ -319,9 +319,9 @@ class PyInferTypeOpInterface class PyShapedTypeComponents { public: PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {} - PyShapedTypeComponents(py::list shape, MlirType elementType) + PyShapedTypeComponents(nb::list shape, MlirType elementType) : shape(std::move(shape)), elementType(elementType), ranked(true) {} - PyShapedTypeComponents(py::list shape, MlirType elementType, + PyShapedTypeComponents(nb::list shape, MlirType elementType, MlirAttribute attribute) : shape(std::move(shape)), elementType(elementType), attribute(attribute), ranked(true) {} @@ -330,10 +330,9 @@ class PyShapedTypeComponents { : shape(other.shape), elementType(other.elementType), attribute(other.attribute), ranked(other.ranked) {} - static void bind(py::module &m) { - py::class_(m, "ShapedTypeComponents", - py::module_local()) - .def_property_readonly( + static void bind(nb::module_ &m) { + nb::class_(m, "ShapedTypeComponents") + .def_prop_ro( "element_type", [](PyShapedTypeComponents &self) { return self.elementType; }, "Returns the element type of the shaped type components.") @@ -342,57 +341,57 @@ class PyShapedTypeComponents { [](PyType &elementType) { return PyShapedTypeComponents(elementType); }, - py::arg("element_type"), + nb::arg("element_type"), "Create an shaped type components object with only the element " "type.") .def_static( "get", - [](py::list shape, PyType &elementType) { + [](nb::list shape, PyType &elementType) { return PyShapedTypeComponents(std::move(shape), elementType); }, - py::arg("shape"), py::arg("element_type"), + nb::arg("shape"), nb::arg("element_type"), "Create a ranked shaped type components object.") .def_static( "get", - [](py::list shape, PyType &elementType, PyAttribute &attribute) { + [](nb::list shape, PyType &elementType, PyAttribute &attribute) { return PyShapedTypeComponents(std::move(shape), elementType, attribute); }, - py::arg("shape"), py::arg("element_type"), py::arg("attribute"), + nb::arg("shape"), nb::arg("element_type"), nb::arg("attribute"), "Create a ranked shaped type components object with attribute.") - .def_property_readonly( + .def_prop_ro( "has_rank", [](PyShapedTypeComponents &self) -> bool { return self.ranked; }, "Returns whether the given shaped type component is ranked.") - .def_property_readonly( + .def_prop_ro( "rank", - [](PyShapedTypeComponents &self) -> py::object { + [](PyShapedTypeComponents &self) -> nb::object { if (!self.ranked) { - return py::none(); + return nb::none(); } - return py::int_(self.shape.size()); + return nb::int_(self.shape.size()); }, "Returns the rank of the given ranked shaped type components. If " "the shaped type components does not have a rank, None is " "returned.") - .def_property_readonly( + .def_prop_ro( "shape", - [](PyShapedTypeComponents &self) -> py::object { + [](PyShapedTypeComponents &self) -> nb::object { if (!self.ranked) { - return py::none(); + return nb::none(); } - return py::list(self.shape); + return nb::list(self.shape); }, "Returns the shape of the ranked shaped type components as a list " "of integers. Returns none if the shaped type component does not " "have a rank."); } - pybind11::object getCapsule(); - static PyShapedTypeComponents createFromCapsule(pybind11::object capsule); + nb::object getCapsule(); + static PyShapedTypeComponents createFromCapsule(nb::object capsule); private: - py::list shape; + nb::list shape; MlirType elementType; MlirAttribute attribute; bool ranked{false}; @@ -424,7 +423,7 @@ class PyInferShapedTypeOpInterface if (!hasRank) { data->inferredShapedTypeComponents.emplace_back(elementType); } else { - py::list shapeList; + nb::list shapeList; for (intptr_t i = 0; i < rank; ++i) { shapeList.append(shape[i]); } @@ -436,7 +435,7 @@ class PyInferShapedTypeOpInterface /// Given the arguments required to build an operation, attempts to infer the /// shaped type components. Throws value_error on failure. std::vector inferReturnTypeComponents( - std::optional operandList, + std::optional operandList, std::optional attributes, void *properties, std::optional> regions, DefaultingPyMlirContext context, DefaultingPyLocation location) { @@ -458,7 +457,7 @@ class PyInferShapedTypeOpInterface mlirRegions.data(), &appendResultsCallback, &data); if (mlirLogicalResultIsFailure(result)) { - throw py::value_error("Failed to infer result shape type components"); + throw nb::value_error("Failed to infer result shape type components"); } return inferredShapedTypeComponents; @@ -467,14 +466,16 @@ class PyInferShapedTypeOpInterface static void bindDerived(ClassTy &cls) { cls.def("inferReturnTypeComponents", &PyInferShapedTypeOpInterface::inferReturnTypeComponents, - py::arg("operands") = py::none(), - py::arg("attributes") = py::none(), py::arg("regions") = py::none(), - py::arg("properties") = py::none(), py::arg("context") = py::none(), - py::arg("loc") = py::none(), inferReturnTypeComponentsDoc); + nb::arg("operands").none() = nb::none(), + nb::arg("attributes").none() = nb::none(), + nb::arg("regions").none() = nb::none(), + nb::arg("properties").none() = nb::none(), + nb::arg("context").none() = nb::none(), + nb::arg("loc").none() = nb::none(), inferReturnTypeComponentsDoc); } }; -void populateIRInterfaces(py::module &m) { +void populateIRInterfaces(nb::module_ &m) { PyInferTypeOpInterface::bind(m); PyShapedTypeComponents::bind(m); PyInferShapedTypeOpInterface::bind(m); diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index 6727860c094a2a..416a14218f125d 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -7,16 +7,19 @@ //===----------------------------------------------------------------------===// #include "IRModule.h" -#include "Globals.h" -#include "PybindUtils.h" -#include "mlir-c/Bindings/Python/Interop.h" -#include "mlir-c/Support.h" +#include +#include #include #include -namespace py = pybind11; +#include "Globals.h" +#include "NanobindUtils.h" +#include "mlir-c/Bindings/Python/Interop.h" +#include "mlir-c/Support.h" + +namespace nb = nanobind; using namespace mlir; using namespace mlir::python; @@ -41,14 +44,14 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { return true; // Since re-entrancy is possible, make a copy of the search prefixes. std::vector localSearchPrefixes = dialectSearchPrefixes; - py::object loaded = py::none(); + nb::object loaded = nb::none(); for (std::string moduleName : localSearchPrefixes) { moduleName.push_back('.'); moduleName.append(dialectNamespace.data(), dialectNamespace.size()); try { - loaded = py::module::import(moduleName.c_str()); - } catch (py::error_already_set &e) { + loaded = nb::module_::import_(moduleName.c_str()); + } catch (nb::python_error &e) { if (e.matches(PyExc_ModuleNotFoundError)) { continue; } @@ -66,41 +69,39 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { } void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, - py::function pyFunc, bool replace) { - py::object &found = attributeBuilderMap[attributeKind]; + nb::callable pyFunc, bool replace) { + nb::object &found = attributeBuilderMap[attributeKind]; if (found && !replace) { throw std::runtime_error((llvm::Twine("Attribute builder for '") + attributeKind + "' is already registered with func: " + - py::str(found).operator std::string()) + nb::cast(nb::str(found))) .str()); } found = std::move(pyFunc); } void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID, - pybind11::function typeCaster, - bool replace) { - pybind11::object &found = typeCasterMap[mlirTypeID]; + nb::callable typeCaster, bool replace) { + nb::object &found = typeCasterMap[mlirTypeID]; if (found && !replace) throw std::runtime_error("Type caster is already registered with caster: " + - py::str(found).operator std::string()); + nb::cast(nb::str(found))); found = std::move(typeCaster); } void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID, - pybind11::function valueCaster, - bool replace) { - pybind11::object &found = valueCasterMap[mlirTypeID]; + nb::callable valueCaster, bool replace) { + nb::object &found = valueCasterMap[mlirTypeID]; if (found && !replace) throw std::runtime_error("Value caster is already registered: " + - py::repr(found).cast()); + nb::cast(nb::repr(found))); found = std::move(valueCaster); } void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, - py::object pyClass) { - py::object &found = dialectClassMap[dialectNamespace]; + nb::object pyClass) { + nb::object &found = dialectClassMap[dialectNamespace]; if (found) { throw std::runtime_error((llvm::Twine("Dialect namespace '") + dialectNamespace + "' is already registered.") @@ -110,8 +111,8 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, } void PyGlobals::registerOperationImpl(const std::string &operationName, - py::object pyClass, bool replace) { - py::object &found = operationClassMap[operationName]; + nb::object pyClass, bool replace) { + nb::object &found = operationClassMap[operationName]; if (found && !replace) { throw std::runtime_error((llvm::Twine("Operation '") + operationName + "' is already registered.") @@ -120,7 +121,7 @@ void PyGlobals::registerOperationImpl(const std::string &operationName, found = std::move(pyClass); } -std::optional +std::optional PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { const auto foundIt = attributeBuilderMap.find(attributeKind); if (foundIt != attributeBuilderMap.end()) { @@ -130,7 +131,7 @@ PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { return std::nullopt; } -std::optional PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, +std::optional PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect) { // Try to load dialect module. (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); @@ -142,7 +143,7 @@ std::optional PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, return std::nullopt; } -std::optional PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID, +std::optional PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect) { // Try to load dialect module. (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); @@ -154,7 +155,7 @@ std::optional PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID, return std::nullopt; } -std::optional +std::optional PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { // Make sure dialect module is loaded. if (!loadDialectModule(dialectNamespace)) @@ -168,7 +169,7 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { return std::nullopt; } -std::optional +std::optional PyGlobals::lookupOperationClass(llvm::StringRef operationName) { // Make sure dialect module is loaded. auto split = operationName.split('.'); diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 172898cfda0c52..a242ff26bbbf57 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -10,20 +10,22 @@ #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H #define MLIR_BINDINGS_PYTHON_IRMODULES_H +#include +#include + #include #include #include #include "Globals.h" -#include "PybindUtils.h" - +#include "NanobindUtils.h" #include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" #include "mlir-c/IntegerSet.h" #include "mlir-c/Transforms.h" -#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "llvm/ADT/DenseMap.h" namespace mlir { @@ -49,7 +51,7 @@ class PyValue; template class PyObjectRef { public: - PyObjectRef(T *referrent, pybind11::object object) + PyObjectRef(T *referrent, nanobind::object object) : referrent(referrent), object(std::move(object)) { assert(this->referrent && "cannot construct PyObjectRef with null referrent"); @@ -67,13 +69,13 @@ class PyObjectRef { int getRefCount() { if (!object) return 0; - return object.ref_count(); + return Py_REFCNT(object.ptr()); } /// Releases the object held by this instance, returning it. /// This is the proper thing to return from a function that wants to return /// the reference. Note that this does not work from initializers. - pybind11::object releaseObject() { + nanobind::object releaseObject() { assert(referrent && object); referrent = nullptr; auto stolen = std::move(object); @@ -85,7 +87,7 @@ class PyObjectRef { assert(referrent && object); return referrent; } - pybind11::object getObject() { + nanobind::object getObject() { assert(referrent && object); return object; } @@ -93,7 +95,7 @@ class PyObjectRef { private: T *referrent; - pybind11::object object; + nanobind::object object; }; /// Tracks an entry in the thread context stack. New entries are pushed onto @@ -112,9 +114,9 @@ class PyThreadContextEntry { Location, }; - PyThreadContextEntry(FrameKind frameKind, pybind11::object context, - pybind11::object insertionPoint, - pybind11::object location) + PyThreadContextEntry(FrameKind frameKind, nanobind::object context, + nanobind::object insertionPoint, + nanobind::object location) : context(std::move(context)), insertionPoint(std::move(insertionPoint)), location(std::move(location)), frameKind(frameKind) {} @@ -134,26 +136,26 @@ class PyThreadContextEntry { /// Stack management. static PyThreadContextEntry *getTopOfStack(); - static pybind11::object pushContext(PyMlirContext &context); + static nanobind::object pushContext(nanobind::object context); static void popContext(PyMlirContext &context); - static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint); + static nanobind::object pushInsertionPoint(nanobind::object insertionPoint); static void popInsertionPoint(PyInsertionPoint &insertionPoint); - static pybind11::object pushLocation(PyLocation &location); + static nanobind::object pushLocation(nanobind::object location); static void popLocation(PyLocation &location); /// Gets the thread local stack. static std::vector &getStack(); private: - static void push(FrameKind frameKind, pybind11::object context, - pybind11::object insertionPoint, pybind11::object location); + static void push(FrameKind frameKind, nanobind::object context, + nanobind::object insertionPoint, nanobind::object location); /// An object reference to the PyContext. - pybind11::object context; + nanobind::object context; /// An object reference to the current insertion point. - pybind11::object insertionPoint; + nanobind::object insertionPoint; /// An object reference to the current location. - pybind11::object location; + nanobind::object location; // The kind of push that was performed. FrameKind frameKind; }; @@ -163,14 +165,15 @@ using PyMlirContextRef = PyObjectRef; class PyMlirContext { public: PyMlirContext() = delete; + PyMlirContext(MlirContext context); PyMlirContext(const PyMlirContext &) = delete; PyMlirContext(PyMlirContext &&) = delete; - /// For the case of a python __init__ (py::init) method, pybind11 is quite - /// strict about needing to return a pointer that is not yet associated to - /// an py::object. Since the forContext() method acts like a pool, possibly - /// returning a recycled context, it does not satisfy this need. The usual - /// way in python to accomplish such a thing is to override __new__, but + /// For the case of a python __init__ (nanobind::init) method, pybind11 is + /// quite strict about needing to return a pointer that is not yet associated + /// to an nanobind::object. Since the forContext() method acts like a pool, + /// possibly returning a recycled context, it does not satisfy this need. The + /// usual way in python to accomplish such a thing is to override __new__, but /// that is also not supported by pybind11. Instead, we use this entry /// point which always constructs a fresh context (which cannot alias an /// existing one because it is fresh). @@ -187,17 +190,17 @@ class PyMlirContext { /// Gets a strong reference to this context, which will ensure it is kept /// alive for the life of the reference. PyMlirContextRef getRef() { - return PyMlirContextRef(this, pybind11::cast(this)); + return PyMlirContextRef(this, nanobind::cast(this)); } /// Gets a capsule wrapping the void* within the MlirContext. - pybind11::object getCapsule(); + nanobind::object getCapsule(); /// Creates a PyMlirContext from the MlirContext wrapped by a capsule. /// Note that PyMlirContext instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirContext /// is taken by calling this function. - static pybind11::object createFromCapsule(pybind11::object capsule); + static nanobind::object createFromCapsule(nanobind::object capsule); /// Gets the count of live context objects. Used for testing. static size_t getLiveCount(); @@ -237,14 +240,14 @@ class PyMlirContext { size_t getLiveModuleCount(); /// Enter and exit the context manager. - pybind11::object contextEnter(); - void contextExit(const pybind11::object &excType, - const pybind11::object &excVal, - const pybind11::object &excTb); + static nanobind::object contextEnter(nanobind::object context); + void contextExit(const nanobind::object &excType, + const nanobind::object &excVal, + const nanobind::object &excTb); /// Attaches a Python callback as a diagnostic handler, returning a /// registration object (internally a PyDiagnosticHandler). - pybind11::object attachDiagnosticHandler(pybind11::object callback); + nanobind::object attachDiagnosticHandler(nanobind::object callback); /// Controls whether error diagnostics should be propagated to diagnostic /// handlers, instead of being captured by `ErrorCapture`. @@ -252,8 +255,6 @@ class PyMlirContext { struct ErrorCapture; private: - PyMlirContext(MlirContext context); - // Interns the mapping of live MlirContext::ptr to PyMlirContext instances, // preserving the relationship that an MlirContext maps to a single // PyMlirContext wrapper. This could be replaced in the future with an @@ -268,7 +269,7 @@ class PyMlirContext { // from this map, and while it still exists as an instance, any // attempt to access it will raise an error. using LiveModuleMap = - llvm::DenseMap>; + llvm::DenseMap>; LiveModuleMap liveModules; // Interns all live operations associated with this context. Operations @@ -276,7 +277,7 @@ class PyMlirContext { // removed from this map, and while it still exists as an instance, any // attempt to access it will raise an error. using LiveOperationMap = - llvm::DenseMap>; + llvm::DenseMap>; LiveOperationMap liveOperations; bool emitErrorDiagnostics = false; @@ -324,19 +325,19 @@ class PyLocation : public BaseContextObject { MlirLocation get() const { return loc; } /// Enter and exit the context manager. - pybind11::object contextEnter(); - void contextExit(const pybind11::object &excType, - const pybind11::object &excVal, - const pybind11::object &excTb); + static nanobind::object contextEnter(nanobind::object location); + void contextExit(const nanobind::object &excType, + const nanobind::object &excVal, + const nanobind::object &excTb); /// Gets a capsule wrapping the void* within the MlirLocation. - pybind11::object getCapsule(); + nanobind::object getCapsule(); /// Creates a PyLocation from the MlirLocation wrapped by a capsule. /// Note that PyLocation instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirLocation /// is taken by calling this function. - static PyLocation createFromCapsule(pybind11::object capsule); + static PyLocation createFromCapsule(nanobind::object capsule); private: MlirLocation loc; @@ -353,8 +354,8 @@ class PyDiagnostic { bool isValid() { return valid; } MlirDiagnosticSeverity getSeverity(); PyLocation getLocation(); - pybind11::str getMessage(); - pybind11::tuple getNotes(); + nanobind::str getMessage(); + nanobind::tuple getNotes(); /// Materialized diagnostic information. This is safe to access outside the /// diagnostic callback. @@ -373,7 +374,7 @@ class PyDiagnostic { /// If notes have been materialized from the diagnostic, then this will /// be populated with the corresponding objects (all castable to /// PyDiagnostic). - std::optional materializedNotes; + std::optional materializedNotes; bool valid = true; }; @@ -398,7 +399,7 @@ class PyDiagnostic { /// is no way to attach an existing handler object). class PyDiagnosticHandler { public: - PyDiagnosticHandler(MlirContext context, pybind11::object callback); + PyDiagnosticHandler(MlirContext context, nanobind::object callback); ~PyDiagnosticHandler(); bool isAttached() { return registeredID.has_value(); } @@ -407,16 +408,16 @@ class PyDiagnosticHandler { /// Detaches the handler. Does nothing if not attached. void detach(); - pybind11::object contextEnter() { return pybind11::cast(this); } - void contextExit(const pybind11::object &excType, - const pybind11::object &excVal, - const pybind11::object &excTb) { + nanobind::object contextEnter() { return nanobind::cast(this); } + void contextExit(const nanobind::object &excType, + const nanobind::object &excVal, + const nanobind::object &excTb) { detach(); } private: MlirContext context; - pybind11::object callback; + nanobind::object callback; std::optional registeredID; bool hadError = false; friend class PyMlirContext; @@ -477,12 +478,12 @@ class PyDialects : public BaseContextObject { /// objects of this type will be returned directly. class PyDialect { public: - PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {} + PyDialect(nanobind::object descriptor) : descriptor(std::move(descriptor)) {} - pybind11::object getDescriptor() { return descriptor; } + nanobind::object getDescriptor() { return descriptor; } private: - pybind11::object descriptor; + nanobind::object descriptor; }; /// Wrapper around an MlirDialectRegistry. @@ -505,8 +506,8 @@ class PyDialectRegistry { operator MlirDialectRegistry() const { return registry; } MlirDialectRegistry get() const { return registry; } - pybind11::object getCapsule(); - static PyDialectRegistry createFromCapsule(pybind11::object capsule); + nanobind::object getCapsule(); + static PyDialectRegistry createFromCapsule(nanobind::object capsule); private: MlirDialectRegistry registry; @@ -542,26 +543,25 @@ class PyModule : public BaseContextObject { /// Gets a strong reference to this module. PyModuleRef getRef() { - return PyModuleRef(this, - pybind11::reinterpret_borrow(handle)); + return PyModuleRef(this, nanobind::borrow(handle)); } /// Gets a capsule wrapping the void* within the MlirModule. /// Note that the module does not (yet) provide a corresponding factory for /// constructing from a capsule as that would require uniquing PyModule /// instances, which is not currently done. - pybind11::object getCapsule(); + nanobind::object getCapsule(); /// Creates a PyModule from the MlirModule wrapped by a capsule. /// Note that PyModule instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirModule /// is taken by calling this function. - static pybind11::object createFromCapsule(pybind11::object capsule); + static nanobind::object createFromCapsule(nanobind::object capsule); private: PyModule(PyMlirContextRef contextRef, MlirModule module); MlirModule module; - pybind11::handle handle; + nanobind::handle handle; }; class PyAsmState; @@ -574,18 +574,18 @@ class PyOperationBase { /// Implements the bound 'print' method and helps with others. void print(std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, - bool assumeVerified, py::object fileObject, bool binary, + bool assumeVerified, nanobind::object fileObject, bool binary, bool skipRegions); - void print(PyAsmState &state, py::object fileObject, bool binary); + void print(PyAsmState &state, nanobind::object fileObject, bool binary); - pybind11::object getAsm(bool binary, + nanobind::object getAsm(bool binary, std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool assumeVerified, bool skipRegions); // Implement the bound 'writeBytecode' method. - void writeBytecode(const pybind11::object &fileObject, + void writeBytecode(const nanobind::object &fileObject, std::optional bytecodeVersion); // Implement the walk method. @@ -621,13 +621,13 @@ class PyOperation : public PyOperationBase, public BaseContextObject { /// it with a parentKeepAlive. static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, - pybind11::object parentKeepAlive = pybind11::object()); + nanobind::object parentKeepAlive = nanobind::object()); /// Creates a detached operation. The operation must not be associated with /// any existing live operation. static PyOperationRef createDetached(PyMlirContextRef contextRef, MlirOperation operation, - pybind11::object parentKeepAlive = pybind11::object()); + nanobind::object parentKeepAlive = nanobind::object()); /// Parses a source string (either text assembly or bytecode), creating a /// detached operation. @@ -640,7 +640,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject { void detachFromParent() { mlirOperationRemoveFromParent(getOperation()); setDetached(); - parentKeepAlive = pybind11::object(); + parentKeepAlive = nanobind::object(); } /// Gets the backing operation. @@ -651,12 +651,11 @@ class PyOperation : public PyOperationBase, public BaseContextObject { } PyOperationRef getRef() { - return PyOperationRef( - this, pybind11::reinterpret_borrow(handle)); + return PyOperationRef(this, nanobind::borrow(handle)); } bool isAttached() { return attached; } - void setAttached(const pybind11::object &parent = pybind11::object()) { + void setAttached(const nanobind::object &parent = nanobind::object()) { assert(!attached && "operation already attached"); attached = true; } @@ -675,24 +674,24 @@ class PyOperation : public PyOperationBase, public BaseContextObject { std::optional getParentOperation(); /// Gets a capsule wrapping the void* within the MlirOperation. - pybind11::object getCapsule(); + nanobind::object getCapsule(); /// Creates a PyOperation from the MlirOperation wrapped by a capsule. /// Ownership of the underlying MlirOperation is taken by calling this /// function. - static pybind11::object createFromCapsule(pybind11::object capsule); + static nanobind::object createFromCapsule(nanobind::object capsule); /// Creates an operation. See corresponding python docstring. - static pybind11::object + static nanobind::object create(const std::string &name, std::optional> results, std::optional> operands, - std::optional attributes, + std::optional attributes, std::optional> successors, int regions, - DefaultingPyLocation location, const pybind11::object &ip, + DefaultingPyLocation location, const nanobind::object &ip, bool inferType); /// Creates an OpView suitable for this operation. - pybind11::object createOpView(); + nanobind::object createOpView(); /// Erases the underlying MlirOperation, removes its pointer from the /// parent context's live operations map, and sets the valid bit false. @@ -702,23 +701,23 @@ class PyOperation : public PyOperationBase, public BaseContextObject { void setInvalid() { valid = false; } /// Clones this operation. - pybind11::object clone(const pybind11::object &ip); + nanobind::object clone(const nanobind::object &ip); private: PyOperation(PyMlirContextRef contextRef, MlirOperation operation); static PyOperationRef createInstance(PyMlirContextRef contextRef, MlirOperation operation, - pybind11::object parentKeepAlive); + nanobind::object parentKeepAlive); MlirOperation operation; - pybind11::handle handle; + nanobind::handle handle; // Keeps the parent alive, regardless of whether it is an Operation or // Module. // TODO: As implemented, this facility is only sufficient for modeling the // trivial module parent back-reference. Generalize this to also account for // transitions from detached to attached and address TODOs in the // ir_operation.py regarding testing corresponding lifetime guarantees. - pybind11::object parentKeepAlive; + nanobind::object parentKeepAlive; bool attached = true; bool valid = true; @@ -733,17 +732,17 @@ class PyOperation : public PyOperationBase, public BaseContextObject { /// python types. class PyOpView : public PyOperationBase { public: - PyOpView(const pybind11::object &operationObject); + PyOpView(const nanobind::object &operationObject); PyOperation &getOperation() override { return operation; } - pybind11::object getOperationObject() { return operationObject; } + nanobind::object getOperationObject() { return operationObject; } - static pybind11::object buildGeneric( - const pybind11::object &cls, std::optional resultTypeList, - pybind11::list operandList, std::optional attributes, + static nanobind::object buildGeneric( + const nanobind::object &cls, std::optional resultTypeList, + nanobind::list operandList, std::optional attributes, std::optional> successors, std::optional regions, DefaultingPyLocation location, - const pybind11::object &maybeIp); + const nanobind::object &maybeIp); /// Construct an instance of a class deriving from OpView, bypassing its /// `__init__` method. The derived class will typically define a constructor @@ -752,12 +751,12 @@ class PyOpView : public PyOperationBase { /// /// The caller is responsible for verifying that `operation` is a valid /// operation to construct `cls` with. - static pybind11::object constructDerived(const pybind11::object &cls, - const PyOperation &operation); + static nanobind::object constructDerived(const nanobind::object &cls, + const nanobind::object &operation); private: PyOperation &operation; // For efficient, cast-free access from C++ - pybind11::object operationObject; // Holds the reference. + nanobind::object operationObject; // Holds the reference. }; /// Wrapper around an MlirRegion. @@ -830,7 +829,7 @@ class PyBlock { void checkValid() { return parentOperation->checkValid(); } /// Gets a capsule wrapping the void* within the MlirBlock. - pybind11::object getCapsule(); + nanobind::object getCapsule(); private: PyOperationRef parentOperation; @@ -858,10 +857,10 @@ class PyInsertionPoint { void insert(PyOperationBase &operationBase); /// Enter and exit the context manager. - pybind11::object contextEnter(); - void contextExit(const pybind11::object &excType, - const pybind11::object &excVal, - const pybind11::object &excTb); + static nanobind::object contextEnter(nanobind::object insertionPoint); + void contextExit(const nanobind::object &excType, + const nanobind::object &excVal, + const nanobind::object &excTb); PyBlock &getBlock() { return block; } std::optional &getRefOperation() { return refOperation; } @@ -886,13 +885,13 @@ class PyType : public BaseContextObject { MlirType get() const { return type; } /// Gets a capsule wrapping the void* within the MlirType. - pybind11::object getCapsule(); + nanobind::object getCapsule(); /// Creates a PyType from the MlirType wrapped by a capsule. /// Note that PyType instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirType /// is taken by calling this function. - static PyType createFromCapsule(pybind11::object capsule); + static PyType createFromCapsule(nanobind::object capsule); private: MlirType type; @@ -912,10 +911,10 @@ class PyTypeID { MlirTypeID get() { return typeID; } /// Gets a capsule wrapping the void* within the MlirTypeID. - pybind11::object getCapsule(); + nanobind::object getCapsule(); /// Creates a PyTypeID from the MlirTypeID wrapped by a capsule. - static PyTypeID createFromCapsule(pybind11::object capsule); + static PyTypeID createFromCapsule(nanobind::object capsule); private: MlirTypeID typeID; @@ -932,7 +931,7 @@ class PyConcreteType : public BaseTy { // Derived classes must define statics for: // IsAFunctionTy isaFunction // const char *pyClassName - using ClassTy = pybind11::class_; + using ClassTy = nanobind::class_; using IsAFunctionTy = bool (*)(MlirType); using GetTypeIDFunctionTy = MlirTypeID (*)(); static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; @@ -945,34 +944,38 @@ class PyConcreteType : public BaseTy { static MlirType castFrom(PyType &orig) { if (!DerivedTy::isaFunction(orig)) { - auto origRepr = pybind11::repr(pybind11::cast(orig)).cast(); - throw py::value_error((llvm::Twine("Cannot cast type to ") + - DerivedTy::pyClassName + " (from " + origRepr + - ")") - .str()); + auto origRepr = + nanobind::cast(nanobind::repr(nanobind::cast(orig))); + throw nanobind::value_error((llvm::Twine("Cannot cast type to ") + + DerivedTy::pyClassName + " (from " + + origRepr + ")") + .str() + .c_str()); } return orig; } - static void bind(pybind11::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local()); - cls.def(pybind11::init(), pybind11::keep_alive<0, 1>(), - pybind11::arg("cast_from_type")); + static void bind(nanobind::module_ &m) { + auto cls = ClassTy(m, DerivedTy::pyClassName); + cls.def(nanobind::init(), nanobind::keep_alive<0, 1>(), + nanobind::arg("cast_from_type")); cls.def_static( "isinstance", [](PyType &otherType) -> bool { return DerivedTy::isaFunction(otherType); }, - pybind11::arg("other")); - cls.def_property_readonly_static( - "static_typeid", [](py::object & /*class*/) -> MlirTypeID { + nanobind::arg("other")); + cls.def_prop_ro_static( + "static_typeid", [](nanobind::object & /*class*/) -> MlirTypeID { if (DerivedTy::getTypeIdFunction) return DerivedTy::getTypeIdFunction(); - throw py::attribute_error( - (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")).str()); + throw nanobind::attribute_error( + (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")) + .str() + .c_str()); }); - cls.def_property_readonly("typeid", [](PyType &self) { - return py::cast(self).attr("typeid").cast(); + cls.def_prop_ro("typeid", [](PyType &self) { + return nanobind::cast(nanobind::cast(self).attr("typeid")); }); cls.def("__repr__", [](DerivedTy &self) { PyPrintAccumulator printAccum; @@ -986,8 +989,8 @@ class PyConcreteType : public BaseTy { if (DerivedTy::getTypeIdFunction) { PyGlobals::get().registerTypeCaster( DerivedTy::getTypeIdFunction(), - pybind11::cpp_function( - [](PyType pyType) -> DerivedTy { return pyType; })); + nanobind::cast(nanobind::cpp_function( + [](PyType pyType) -> DerivedTy { return pyType; }))); } DerivedTy::bindDerived(cls); @@ -1008,13 +1011,13 @@ class PyAttribute : public BaseContextObject { MlirAttribute get() const { return attr; } /// Gets a capsule wrapping the void* within the MlirAttribute. - pybind11::object getCapsule(); + nanobind::object getCapsule(); /// Creates a PyAttribute from the MlirAttribute wrapped by a capsule. /// Note that PyAttribute instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirAttribute /// is taken by calling this function. - static PyAttribute createFromCapsule(pybind11::object capsule); + static PyAttribute createFromCapsule(nanobind::object capsule); private: MlirAttribute attr; @@ -1054,7 +1057,7 @@ class PyConcreteAttribute : public BaseTy { // Derived classes must define statics for: // IsAFunctionTy isaFunction // const char *pyClassName - using ClassTy = pybind11::class_; + using ClassTy = nanobind::class_; using IsAFunctionTy = bool (*)(MlirAttribute); using GetTypeIDFunctionTy = MlirTypeID (*)(); static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; @@ -1067,37 +1070,45 @@ class PyConcreteAttribute : public BaseTy { static MlirAttribute castFrom(PyAttribute &orig) { if (!DerivedTy::isaFunction(orig)) { - auto origRepr = pybind11::repr(pybind11::cast(orig)).cast(); - throw py::value_error((llvm::Twine("Cannot cast attribute to ") + - DerivedTy::pyClassName + " (from " + origRepr + - ")") - .str()); + auto origRepr = + nanobind::cast(nanobind::repr(nanobind::cast(orig))); + throw nanobind::value_error((llvm::Twine("Cannot cast attribute to ") + + DerivedTy::pyClassName + " (from " + + origRepr + ")") + .str() + .c_str()); } return orig; } - static void bind(pybind11::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol(), - pybind11::module_local()); - cls.def(pybind11::init(), pybind11::keep_alive<0, 1>(), - pybind11::arg("cast_from_attr")); + static void bind(nanobind::module_ &m, PyType_Slot *slots = nullptr) { + ClassTy cls; + if (slots) { + cls = ClassTy(m, DerivedTy::pyClassName, nanobind::type_slots(slots)); + } else { + cls = ClassTy(m, DerivedTy::pyClassName); + } + cls.def(nanobind::init(), nanobind::keep_alive<0, 1>(), + nanobind::arg("cast_from_attr")); cls.def_static( "isinstance", [](PyAttribute &otherAttr) -> bool { return DerivedTy::isaFunction(otherAttr); }, - pybind11::arg("other")); - cls.def_property_readonly( + nanobind::arg("other")); + cls.def_prop_ro( "type", [](PyAttribute &attr) { return mlirAttributeGetType(attr); }); - cls.def_property_readonly_static( - "static_typeid", [](py::object & /*class*/) -> MlirTypeID { + cls.def_prop_ro_static( + "static_typeid", [](nanobind::object & /*class*/) -> MlirTypeID { if (DerivedTy::getTypeIdFunction) return DerivedTy::getTypeIdFunction(); - throw py::attribute_error( - (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")).str()); + throw nanobind::attribute_error( + (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")) + .str() + .c_str()); }); - cls.def_property_readonly("typeid", [](PyAttribute &self) { - return py::cast(self).attr("typeid").cast(); + cls.def_prop_ro("typeid", [](PyAttribute &self) { + return nanobind::cast(nanobind::cast(self).attr("typeid")); }); cls.def("__repr__", [](DerivedTy &self) { PyPrintAccumulator printAccum; @@ -1112,9 +1123,10 @@ class PyConcreteAttribute : public BaseTy { if (DerivedTy::getTypeIdFunction) { PyGlobals::get().registerTypeCaster( DerivedTy::getTypeIdFunction(), - pybind11::cpp_function([](PyAttribute pyAttribute) -> DerivedTy { - return pyAttribute; - })); + nanobind::cast( + nanobind::cpp_function([](PyAttribute pyAttribute) -> DerivedTy { + return pyAttribute; + }))); } DerivedTy::bindDerived(cls); @@ -1146,13 +1158,13 @@ class PyValue { void checkValid() { return parentOperation->checkValid(); } /// Gets a capsule wrapping the void* within the MlirValue. - pybind11::object getCapsule(); + nanobind::object getCapsule(); - pybind11::object maybeDownCast(); + nanobind::object maybeDownCast(); /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of /// the underlying MlirValue is still tied to the owning operation. - static PyValue createFromCapsule(pybind11::object capsule); + static PyValue createFromCapsule(nanobind::object capsule); private: PyOperationRef parentOperation; @@ -1169,13 +1181,13 @@ class PyAffineExpr : public BaseContextObject { MlirAffineExpr get() const { return affineExpr; } /// Gets a capsule wrapping the void* within the MlirAffineExpr. - pybind11::object getCapsule(); + nanobind::object getCapsule(); /// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule. /// Note that PyAffineExpr instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirAffineExpr /// is taken by calling this function. - static PyAffineExpr createFromCapsule(pybind11::object capsule); + static PyAffineExpr createFromCapsule(nanobind::object capsule); PyAffineExpr add(const PyAffineExpr &other) const; PyAffineExpr mul(const PyAffineExpr &other) const; @@ -1196,13 +1208,13 @@ class PyAffineMap : public BaseContextObject { MlirAffineMap get() const { return affineMap; } /// Gets a capsule wrapping the void* within the MlirAffineMap. - pybind11::object getCapsule(); + nanobind::object getCapsule(); /// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule. /// Note that PyAffineMap instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirAffineMap /// is taken by calling this function. - static PyAffineMap createFromCapsule(pybind11::object capsule); + static PyAffineMap createFromCapsule(nanobind::object capsule); private: MlirAffineMap affineMap; @@ -1217,12 +1229,12 @@ class PyIntegerSet : public BaseContextObject { MlirIntegerSet get() const { return integerSet; } /// Gets a capsule wrapping the void* within the MlirIntegerSet. - pybind11::object getCapsule(); + nanobind::object getCapsule(); /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule. /// Note that PyIntegerSet instances may be uniqued, so the returned object /// may be a pre-existing object. Integer sets are owned by the context. - static PyIntegerSet createFromCapsule(pybind11::object capsule); + static PyIntegerSet createFromCapsule(nanobind::object capsule); private: MlirIntegerSet integerSet; @@ -1239,7 +1251,7 @@ class PySymbolTable { /// Returns the symbol (opview) with the given name, throws if there is no /// such symbol in the table. - pybind11::object dunderGetItem(const std::string &name); + nanobind::object dunderGetItem(const std::string &name); /// Removes the given operation from the symbol table and erases it. void erase(PyOperationBase &symbol); @@ -1269,7 +1281,7 @@ class PySymbolTable { /// Walks all symbol tables under and including 'from'. static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, - pybind11::object callback); + nanobind::object callback); /// Casts the bindings class into the C API structure. operator MlirSymbolTable() { return symbolTable; } @@ -1289,16 +1301,16 @@ struct MLIRError { std::vector errorDiagnostics; }; -void populateIRAffine(pybind11::module &m); -void populateIRAttributes(pybind11::module &m); -void populateIRCore(pybind11::module &m); -void populateIRInterfaces(pybind11::module &m); -void populateIRTypes(pybind11::module &m); +void populateIRAffine(nanobind::module_ &m); +void populateIRAttributes(nanobind::module_ &m); +void populateIRCore(nanobind::module_ &m); +void populateIRInterfaces(nanobind::module_ &m); +void populateIRTypes(nanobind::module_ &m); } // namespace python } // namespace mlir -namespace pybind11 { +namespace nanobind { namespace detail { template <> @@ -1309,6 +1321,6 @@ struct type_caster : MlirDefaultingCaster {}; } // namespace detail -} // namespace pybind11 +} // namespace nanobind #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 6f192bc4bffeef..5cfa51142ac08f 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -6,19 +6,26 @@ // //===----------------------------------------------------------------------===// +// clang-format off #include "IRModule.h" +#include "mlir/Bindings/Python/IRTypes.h" +// clang-format on -#include "PybindUtils.h" +#include +#include +#include +#include +#include -#include "mlir/Bindings/Python/IRTypes.h" +#include +#include "IRModule.h" +#include "NanobindUtils.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/Support.h" -#include - -namespace py = pybind11; +namespace nb = nanobind; using namespace mlir; using namespace mlir::python; @@ -48,7 +55,7 @@ class PyIntegerType : public PyConcreteType { MlirType t = mlirIntegerTypeGet(context->get(), width); return PyIntegerType(context->getRef(), t); }, - py::arg("width"), py::arg("context") = py::none(), + nb::arg("width"), nb::arg("context").none() = nb::none(), "Create a signless integer type"); c.def_static( "get_signed", @@ -56,7 +63,7 @@ class PyIntegerType : public PyConcreteType { MlirType t = mlirIntegerTypeSignedGet(context->get(), width); return PyIntegerType(context->getRef(), t); }, - py::arg("width"), py::arg("context") = py::none(), + nb::arg("width"), nb::arg("context").none() = nb::none(), "Create a signed integer type"); c.def_static( "get_unsigned", @@ -64,25 +71,25 @@ class PyIntegerType : public PyConcreteType { MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width); return PyIntegerType(context->getRef(), t); }, - py::arg("width"), py::arg("context") = py::none(), + nb::arg("width"), nb::arg("context").none() = nb::none(), "Create an unsigned integer type"); - c.def_property_readonly( + c.def_prop_ro( "width", [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); }, "Returns the width of the integer type"); - c.def_property_readonly( + c.def_prop_ro( "is_signless", [](PyIntegerType &self) -> bool { return mlirIntegerTypeIsSignless(self); }, "Returns whether this is a signless integer"); - c.def_property_readonly( + c.def_prop_ro( "is_signed", [](PyIntegerType &self) -> bool { return mlirIntegerTypeIsSigned(self); }, "Returns whether this is a signed integer"); - c.def_property_readonly( + c.def_prop_ro( "is_unsigned", [](PyIntegerType &self) -> bool { return mlirIntegerTypeIsUnsigned(self); @@ -107,7 +114,7 @@ class PyIndexType : public PyConcreteType { MlirType t = mlirIndexTypeGet(context->get()); return PyIndexType(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a index type."); + nb::arg("context").none() = nb::none(), "Create a index type."); } }; @@ -118,7 +125,7 @@ class PyFloatType : public PyConcreteType { using PyConcreteType::PyConcreteType; static void bindDerived(ClassTy &c) { - c.def_property_readonly( + c.def_prop_ro( "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); }, "Returns the width of the floating-point type"); } @@ -141,7 +148,7 @@ class PyFloat4E2M1FNType MlirType t = mlirFloat4E2M1FNTypeGet(context->get()); return PyFloat4E2M1FNType(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float4_e2m1fn type."); + nb::arg("context").none() = nb::none(), "Create a float4_e2m1fn type."); } }; @@ -162,7 +169,7 @@ class PyFloat6E2M3FNType MlirType t = mlirFloat6E2M3FNTypeGet(context->get()); return PyFloat6E2M3FNType(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float6_e2m3fn type."); + nb::arg("context").none() = nb::none(), "Create a float6_e2m3fn type."); } }; @@ -183,7 +190,7 @@ class PyFloat6E3M2FNType MlirType t = mlirFloat6E3M2FNTypeGet(context->get()); return PyFloat6E3M2FNType(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float6_e3m2fn type."); + nb::arg("context").none() = nb::none(), "Create a float6_e3m2fn type."); } }; @@ -204,7 +211,7 @@ class PyFloat8E4M3FNType MlirType t = mlirFloat8E4M3FNTypeGet(context->get()); return PyFloat8E4M3FNType(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float8_e4m3fn type."); + nb::arg("context").none() = nb::none(), "Create a float8_e4m3fn type."); } }; @@ -224,7 +231,7 @@ class PyFloat8E5M2Type : public PyConcreteType { MlirType t = mlirFloat8E5M2TypeGet(context->get()); return PyFloat8E5M2Type(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float8_e5m2 type."); + nb::arg("context").none() = nb::none(), "Create a float8_e5m2 type."); } }; @@ -244,7 +251,7 @@ class PyFloat8E4M3Type : public PyConcreteType { MlirType t = mlirFloat8E4M3TypeGet(context->get()); return PyFloat8E4M3Type(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float8_e4m3 type."); + nb::arg("context").none() = nb::none(), "Create a float8_e4m3 type."); } }; @@ -265,7 +272,8 @@ class PyFloat8E4M3FNUZType MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get()); return PyFloat8E4M3FNUZType(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float8_e4m3fnuz type."); + nb::arg("context").none() = nb::none(), + "Create a float8_e4m3fnuz type."); } }; @@ -286,7 +294,8 @@ class PyFloat8E4M3B11FNUZType MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get()); return PyFloat8E4M3B11FNUZType(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float8_e4m3b11fnuz type."); + nb::arg("context").none() = nb::none(), + "Create a float8_e4m3b11fnuz type."); } }; @@ -307,7 +316,8 @@ class PyFloat8E5M2FNUZType MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get()); return PyFloat8E5M2FNUZType(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float8_e5m2fnuz type."); + nb::arg("context").none() = nb::none(), + "Create a float8_e5m2fnuz type."); } }; @@ -327,7 +337,7 @@ class PyFloat8E3M4Type : public PyConcreteType { MlirType t = mlirFloat8E3M4TypeGet(context->get()); return PyFloat8E3M4Type(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float8_e3m4 type."); + nb::arg("context").none() = nb::none(), "Create a float8_e3m4 type."); } }; @@ -348,7 +358,8 @@ class PyFloat8E8M0FNUType MlirType t = mlirFloat8E8M0FNUTypeGet(context->get()); return PyFloat8E8M0FNUType(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float8_e8m0fnu type."); + nb::arg("context").none() = nb::none(), + "Create a float8_e8m0fnu type."); } }; @@ -368,7 +379,7 @@ class PyBF16Type : public PyConcreteType { MlirType t = mlirBF16TypeGet(context->get()); return PyBF16Type(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a bf16 type."); + nb::arg("context").none() = nb::none(), "Create a bf16 type."); } }; @@ -388,7 +399,7 @@ class PyF16Type : public PyConcreteType { MlirType t = mlirF16TypeGet(context->get()); return PyF16Type(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a f16 type."); + nb::arg("context").none() = nb::none(), "Create a f16 type."); } }; @@ -408,7 +419,7 @@ class PyTF32Type : public PyConcreteType { MlirType t = mlirTF32TypeGet(context->get()); return PyTF32Type(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a tf32 type."); + nb::arg("context").none() = nb::none(), "Create a tf32 type."); } }; @@ -428,7 +439,7 @@ class PyF32Type : public PyConcreteType { MlirType t = mlirF32TypeGet(context->get()); return PyF32Type(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a f32 type."); + nb::arg("context").none() = nb::none(), "Create a f32 type."); } }; @@ -448,7 +459,7 @@ class PyF64Type : public PyConcreteType { MlirType t = mlirF64TypeGet(context->get()); return PyF64Type(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a f64 type."); + nb::arg("context").none() = nb::none(), "Create a f64 type."); } }; @@ -468,7 +479,7 @@ class PyNoneType : public PyConcreteType { MlirType t = mlirNoneTypeGet(context->get()); return PyNoneType(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a none type."); + nb::arg("context").none() = nb::none(), "Create a none type."); } }; @@ -490,14 +501,15 @@ class PyComplexType : public PyConcreteType { MlirType t = mlirComplexTypeGet(elementType); return PyComplexType(elementType.getContext(), t); } - throw py::value_error( + throw nb::value_error( (Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + + nb::cast(nb::repr(nb::cast(elementType))) + "' and expected floating point or integer type.") - .str()); + .str() + .c_str()); }, "Create a complex type"); - c.def_property_readonly( + c.def_prop_ro( "element_type", [](PyComplexType &self) { return mlirComplexTypeGetElementType(self); }, "Returns element type."); @@ -508,22 +520,22 @@ class PyComplexType : public PyConcreteType { // Shaped Type Interface - ShapedType void mlir::PyShapedType::bindDerived(ClassTy &c) { - c.def_property_readonly( + c.def_prop_ro( "element_type", [](PyShapedType &self) { return mlirShapedTypeGetElementType(self); }, "Returns the element type of the shaped type."); - c.def_property_readonly( + c.def_prop_ro( "has_rank", [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); }, "Returns whether the given shaped type is ranked."); - c.def_property_readonly( + c.def_prop_ro( "rank", [](PyShapedType &self) { self.requireHasRank(); return mlirShapedTypeGetRank(self); }, "Returns the rank of the given ranked shaped type."); - c.def_property_readonly( + c.def_prop_ro( "has_static_shape", [](PyShapedType &self) -> bool { return mlirShapedTypeHasStaticShape(self); @@ -535,7 +547,7 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) { self.requireHasRank(); return mlirShapedTypeIsDynamicDim(self, dim); }, - py::arg("dim"), + nb::arg("dim"), "Returns whether the dim-th dimension of the given shaped type is " "dynamic."); c.def( @@ -544,12 +556,12 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) { self.requireHasRank(); return mlirShapedTypeGetDimSize(self, dim); }, - py::arg("dim"), + nb::arg("dim"), "Returns the dim-th dimension of the given ranked shaped type."); c.def_static( "is_dynamic_size", [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); }, - py::arg("dim_size"), + nb::arg("dim_size"), "Returns whether the given dimension size indicates a dynamic " "dimension."); c.def( @@ -558,10 +570,10 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) { self.requireHasRank(); return mlirShapedTypeIsDynamicStrideOrOffset(val); }, - py::arg("dim_size"), + nb::arg("dim_size"), "Returns whether the given value is used as a placeholder for dynamic " "strides and offsets in shaped types."); - c.def_property_readonly( + c.def_prop_ro( "shape", [](PyShapedType &self) { self.requireHasRank(); @@ -587,7 +599,7 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) { void mlir::PyShapedType::requireHasRank() { if (!mlirShapedTypeHasRank(*this)) { - throw py::value_error( + throw nb::value_error( "calling this method requires that the type has a rank."); } } @@ -607,15 +619,15 @@ class PyVectorType : public PyConcreteType { using PyConcreteType::PyConcreteType; static void bindDerived(ClassTy &c) { - c.def_static("get", &PyVectorType::get, py::arg("shape"), - py::arg("element_type"), py::kw_only(), - py::arg("scalable") = py::none(), - py::arg("scalable_dims") = py::none(), - py::arg("loc") = py::none(), "Create a vector type") - .def_property_readonly( + c.def_static("get", &PyVectorType::get, nb::arg("shape"), + nb::arg("element_type"), nb::kw_only(), + nb::arg("scalable").none() = nb::none(), + nb::arg("scalable_dims").none() = nb::none(), + nb::arg("loc").none() = nb::none(), "Create a vector type") + .def_prop_ro( "scalable", [](MlirType self) { return mlirVectorTypeIsScalable(self); }) - .def_property_readonly("scalable_dims", [](MlirType self) { + .def_prop_ro("scalable_dims", [](MlirType self) { std::vector scalableDims; size_t rank = static_cast(mlirShapedTypeGetRank(self)); scalableDims.reserve(rank); @@ -627,11 +639,11 @@ class PyVectorType : public PyConcreteType { private: static PyVectorType get(std::vector shape, PyType &elementType, - std::optional scalable, + std::optional scalable, std::optional> scalableDims, DefaultingPyLocation loc) { if (scalable && scalableDims) { - throw py::value_error("'scalable' and 'scalable_dims' kwargs " + throw nb::value_error("'scalable' and 'scalable_dims' kwargs " "are mutually exclusive."); } @@ -639,10 +651,10 @@ class PyVectorType : public PyConcreteType { MlirType type; if (scalable) { if (scalable->size() != shape.size()) - throw py::value_error("Expected len(scalable) == len(shape)."); + throw nb::value_error("Expected len(scalable) == len(shape)."); SmallVector scalableDimFlags = llvm::to_vector(llvm::map_range( - *scalable, [](const py::handle &h) { return h.cast(); })); + *scalable, [](const nb::handle &h) { return nb::cast(h); })); type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(), scalableDimFlags.data(), elementType); @@ -650,7 +662,7 @@ class PyVectorType : public PyConcreteType { SmallVector scalableDimFlags(shape.size(), false); for (int64_t dim : *scalableDims) { if (static_cast(dim) >= scalableDimFlags.size() || dim < 0) - throw py::value_error("Scalable dimension index out of bounds."); + throw nb::value_error("Scalable dimension index out of bounds."); scalableDimFlags[dim] = true; } type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(), @@ -689,17 +701,17 @@ class PyRankedTensorType throw MLIRError("Invalid type", errors.take()); return PyRankedTensorType(elementType.getContext(), t); }, - py::arg("shape"), py::arg("element_type"), - py::arg("encoding") = py::none(), py::arg("loc") = py::none(), - "Create a ranked tensor type"); - c.def_property_readonly( - "encoding", - [](PyRankedTensorType &self) -> std::optional { - MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get()); - if (mlirAttributeIsNull(encoding)) - return std::nullopt; - return encoding; - }); + nb::arg("shape"), nb::arg("element_type"), + nb::arg("encoding").none() = nb::none(), + nb::arg("loc").none() = nb::none(), "Create a ranked tensor type"); + c.def_prop_ro("encoding", + [](PyRankedTensorType &self) -> std::optional { + MlirAttribute encoding = + mlirRankedTensorTypeGetEncoding(self.get()); + if (mlirAttributeIsNull(encoding)) + return std::nullopt; + return encoding; + }); } }; @@ -723,7 +735,7 @@ class PyUnrankedTensorType throw MLIRError("Invalid type", errors.take()); return PyUnrankedTensorType(elementType.getContext(), t); }, - py::arg("element_type"), py::arg("loc") = py::none(), + nb::arg("element_type"), nb::arg("loc").none() = nb::none(), "Create a unranked tensor type"); } }; @@ -754,10 +766,11 @@ class PyMemRefType : public PyConcreteType { throw MLIRError("Invalid type", errors.take()); return PyMemRefType(elementType.getContext(), t); }, - py::arg("shape"), py::arg("element_type"), - py::arg("layout") = py::none(), py::arg("memory_space") = py::none(), - py::arg("loc") = py::none(), "Create a memref type") - .def_property_readonly( + nb::arg("shape"), nb::arg("element_type"), + nb::arg("layout").none() = nb::none(), + nb::arg("memory_space").none() = nb::none(), + nb::arg("loc").none() = nb::none(), "Create a memref type") + .def_prop_ro( "layout", [](PyMemRefType &self) -> MlirAttribute { return mlirMemRefTypeGetLayout(self); @@ -775,14 +788,14 @@ class PyMemRefType : public PyConcreteType { return {strides, offset}; }, "The strides and offset of the MemRef type.") - .def_property_readonly( + .def_prop_ro( "affine_map", [](PyMemRefType &self) -> PyAffineMap { MlirAffineMap map = mlirMemRefTypeGetAffineMap(self); return PyAffineMap(self.getContext(), map); }, "The layout of the MemRef type as an affine map.") - .def_property_readonly( + .def_prop_ro( "memory_space", [](PyMemRefType &self) -> std::optional { MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); @@ -820,9 +833,9 @@ class PyUnrankedMemRefType throw MLIRError("Invalid type", errors.take()); return PyUnrankedMemRefType(elementType.getContext(), t); }, - py::arg("element_type"), py::arg("memory_space"), - py::arg("loc") = py::none(), "Create a unranked memref type") - .def_property_readonly( + nb::arg("element_type"), nb::arg("memory_space").none(), + nb::arg("loc").none() = nb::none(), "Create a unranked memref type") + .def_prop_ro( "memory_space", [](PyUnrankedMemRefType &self) -> std::optional { MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self); @@ -851,15 +864,15 @@ class PyTupleType : public PyConcreteType { elements.data()); return PyTupleType(context->getRef(), t); }, - py::arg("elements"), py::arg("context") = py::none(), + nb::arg("elements"), nb::arg("context").none() = nb::none(), "Create a tuple type"); c.def( "get_type", [](PyTupleType &self, intptr_t pos) { return mlirTupleTypeGetType(self, pos); }, - py::arg("pos"), "Returns the pos-th type in the tuple type."); - c.def_property_readonly( + nb::arg("pos"), "Returns the pos-th type in the tuple type."); + c.def_prop_ro( "num_types", [](PyTupleType &self) -> intptr_t { return mlirTupleTypeGetNumTypes(self); @@ -887,13 +900,14 @@ class PyFunctionType : public PyConcreteType { results.size(), results.data()); return PyFunctionType(context->getRef(), t); }, - py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(), + nb::arg("inputs"), nb::arg("results"), + nb::arg("context").none() = nb::none(), "Gets a FunctionType from a list of input and result types"); - c.def_property_readonly( + c.def_prop_ro( "inputs", [](PyFunctionType &self) { MlirType t = self; - py::list types; + nb::list types; for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e; ++i) { types.append(mlirFunctionTypeGetInput(t, i)); @@ -901,10 +915,10 @@ class PyFunctionType : public PyConcreteType { return types; }, "Returns the list of input types in the FunctionType."); - c.def_property_readonly( + c.def_prop_ro( "results", [](PyFunctionType &self) { - py::list types; + nb::list types; for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e; ++i) { types.append(mlirFunctionTypeGetResult(self, i)); @@ -938,21 +952,21 @@ class PyOpaqueType : public PyConcreteType { toMlirStringRef(typeData)); return PyOpaqueType(context->getRef(), type); }, - py::arg("dialect_namespace"), py::arg("buffer"), - py::arg("context") = py::none(), + nb::arg("dialect_namespace"), nb::arg("buffer"), + nb::arg("context").none() = nb::none(), "Create an unregistered (opaque) dialect type."); - c.def_property_readonly( + c.def_prop_ro( "dialect_namespace", [](PyOpaqueType &self) { MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self); - return py::str(stringRef.data, stringRef.length); + return nb::str(stringRef.data, stringRef.length); }, "Returns the dialect namespace for the Opaque type as a string."); - c.def_property_readonly( + c.def_prop_ro( "data", [](PyOpaqueType &self) { MlirStringRef stringRef = mlirOpaqueTypeGetData(self); - return py::str(stringRef.data, stringRef.length); + return nb::str(stringRef.data, stringRef.length); }, "Returns the data for the Opaque type as a string."); } @@ -960,7 +974,7 @@ class PyOpaqueType : public PyConcreteType { } // namespace -void mlir::python::populateIRTypes(py::module &m) { +void mlir::python::populateIRTypes(nb::module_ &m) { PyIntegerType::bind(m); PyFloatType::bind(m); PyIndexType::bind(m); diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 7c27021902de31..e5e64a921a79ad 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -6,29 +6,31 @@ // //===----------------------------------------------------------------------===// -#include "PybindUtils.h" +#include +#include #include "Globals.h" #include "IRModule.h" +#include "NanobindUtils.h" #include "Pass.h" #include "Rewrite.h" -namespace py = pybind11; +namespace nb = nanobind; using namespace mlir; -using namespace py::literals; +using namespace nb::literals; using namespace mlir::python; // ----------------------------------------------------------------------------- // Module initialization. // ----------------------------------------------------------------------------- -PYBIND11_MODULE(_mlir, m) { +NB_MODULE(_mlir, m) { m.doc() = "MLIR Python Native Extension"; - py::class_(m, "_Globals", py::module_local()) - .def_property("dialect_search_modules", - &PyGlobals::getDialectSearchPrefixes, - &PyGlobals::setDialectSearchPrefixes) + nb::class_(m, "_Globals") + .def_prop_rw("dialect_search_modules", + &PyGlobals::getDialectSearchPrefixes, + &PyGlobals::setDialectSearchPrefixes) .def( "append_dialect_search_prefix", [](PyGlobals &self, std::string moduleName) { @@ -45,22 +47,21 @@ PYBIND11_MODULE(_mlir, m) { "dialect_namespace"_a, "dialect_class"_a, "Testing hook for directly registering a dialect") .def("_register_operation_impl", &PyGlobals::registerOperationImpl, - "operation_name"_a, "operation_class"_a, py::kw_only(), + "operation_name"_a, "operation_class"_a, nb::kw_only(), "replace"_a = false, "Testing hook for directly registering an operation"); // Aside from making the globals accessible to python, having python manage // it is necessary to make sure it is destroyed (and releases its python // resources) properly. - m.attr("globals") = - py::cast(new PyGlobals, py::return_value_policy::take_ownership); + m.attr("globals") = nb::cast(new PyGlobals, nb::rv_policy::take_ownership); // Registration decorators. m.def( "register_dialect", - [](py::type pyClass) { + [](nb::type_object pyClass) { std::string dialectNamespace = - pyClass.attr("DIALECT_NAMESPACE").cast(); + nanobind::cast(pyClass.attr("DIALECT_NAMESPACE")); PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); return pyClass; }, @@ -68,45 +69,46 @@ PYBIND11_MODULE(_mlir, m) { "Class decorator for registering a custom Dialect wrapper"); m.def( "register_operation", - [](const py::type &dialectClass, bool replace) -> py::cpp_function { - return py::cpp_function( - [dialectClass, replace](py::type opClass) -> py::type { + [](const nb::type_object &dialectClass, bool replace) -> nb::object { + return nb::cpp_function( + [dialectClass, + replace](nb::type_object opClass) -> nb::type_object { std::string operationName = - opClass.attr("OPERATION_NAME").cast(); + nanobind::cast(opClass.attr("OPERATION_NAME")); PyGlobals::get().registerOperationImpl(operationName, opClass, replace); // Dict-stuff the new opClass by name onto the dialect class. - py::object opClassName = opClass.attr("__name__"); + nb::object opClassName = opClass.attr("__name__"); dialectClass.attr(opClassName) = opClass; return opClass; }); }, - "dialect_class"_a, py::kw_only(), "replace"_a = false, + "dialect_class"_a, nb::kw_only(), "replace"_a = false, "Produce a class decorator for registering an Operation class as part of " "a dialect"); m.def( MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR, - [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function { - return py::cpp_function([mlirTypeID, - replace](py::object typeCaster) -> py::object { + [](MlirTypeID mlirTypeID, bool replace) -> nb::object { + return nb::cpp_function([mlirTypeID, replace]( + nb::callable typeCaster) -> nb::object { PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace); return typeCaster; }); }, - "typeid"_a, py::kw_only(), "replace"_a = false, + "typeid"_a, nb::kw_only(), "replace"_a = false, "Register a type caster for casting MLIR types to custom user types."); m.def( MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR, - [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function { - return py::cpp_function( - [mlirTypeID, replace](py::object valueCaster) -> py::object { + [](MlirTypeID mlirTypeID, bool replace) -> nb::object { + return nb::cpp_function( + [mlirTypeID, replace](nb::callable valueCaster) -> nb::object { PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster, replace); return valueCaster; }); }, - "typeid"_a, py::kw_only(), "replace"_a = false, + "typeid"_a, nb::kw_only(), "replace"_a = false, "Register a value caster for casting MLIR values to custom user values."); // Define and populate IR submodule. diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/NanobindUtils.h similarity index 85% rename from mlir/lib/Bindings/Python/PybindUtils.h rename to mlir/lib/Bindings/Python/NanobindUtils.h index 38462ac8ba6db9..3b0f7f698b22d4 100644 --- a/mlir/lib/Bindings/Python/PybindUtils.h +++ b/mlir/lib/Bindings/Python/NanobindUtils.h @@ -1,4 +1,5 @@ -//===- PybindUtils.h - Utilities for interop with pybind11 ------*- C++ -*-===// +//===- NanobindUtils.h - Utilities for interop with nanobind ------*- C++ +//-*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -9,13 +10,21 @@ #ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H +#include + #include "mlir-c/Support.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/DataTypes.h" -#include -#include +template <> +struct std::iterator_traits { + using value_type = nanobind::handle; + using reference = const value_type; + using pointer = void; + using difference_type = std::ptrdiff_t; + using iterator_category = std::forward_iterator_tag; +}; namespace mlir { namespace python { @@ -54,14 +63,14 @@ class Defaulting { } // namespace python } // namespace mlir -namespace pybind11 { +namespace nanobind { namespace detail { template struct MlirDefaultingCaster { - PYBIND11_TYPE_CASTER(DefaultingTy, _(DefaultingTy::kTypeDescription)); + NB_TYPE_CASTER(DefaultingTy, const_name(DefaultingTy::kTypeDescription)); - bool load(pybind11::handle src, bool) { + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { if (src.is_none()) { // Note that we do want an exception to propagate from here as it will be // the most informative. @@ -76,20 +85,20 @@ struct MlirDefaultingCaster { // code to produce nice error messages (other than "Cannot cast..."). try { value = DefaultingTy{ - pybind11::cast(src)}; + nanobind::cast(src)}; return true; } catch (std::exception &) { return false; } } - static handle cast(DefaultingTy src, return_value_policy policy, - handle parent) { - return pybind11::cast(src, policy); + static handle from_cpp(DefaultingTy src, rv_policy policy, + cleanup_list *cleanup) noexcept { + return nanobind::cast(src, policy); } }; } // namespace detail -} // namespace pybind11 +} // namespace nanobind //------------------------------------------------------------------------------ // Conversion utilities. @@ -100,7 +109,7 @@ namespace mlir { /// Accumulates into a python string from a method that accepts an /// MlirStringCallback. struct PyPrintAccumulator { - pybind11::list parts; + nanobind::list parts; void *getUserData() { return this; } @@ -108,15 +117,15 @@ struct PyPrintAccumulator { return [](MlirStringRef part, void *userData) { PyPrintAccumulator *printAccum = static_cast(userData); - pybind11::str pyPart(part.data, + nanobind::str pyPart(part.data, part.length); // Decodes as UTF-8 by default. printAccum->parts.append(std::move(pyPart)); }; } - pybind11::str join() { - pybind11::str delim("", 0); - return delim.attr("join")(parts); + nanobind::str join() { + nanobind::str delim("", 0); + return nanobind::cast(delim.attr("join")(parts)); } }; @@ -124,21 +133,21 @@ struct PyPrintAccumulator { /// or binary. class PyFileAccumulator { public: - PyFileAccumulator(const pybind11::object &fileObject, bool binary) + PyFileAccumulator(const nanobind::object &fileObject, bool binary) : pyWriteFunction(fileObject.attr("write")), binary(binary) {} void *getUserData() { return this; } MlirStringCallback getCallback() { return [](MlirStringRef part, void *userData) { - pybind11::gil_scoped_acquire acquire; + nanobind::gil_scoped_acquire acquire; PyFileAccumulator *accum = static_cast(userData); if (accum->binary) { // Note: Still has to copy and not avoidable with this API. - pybind11::bytes pyBytes(part.data, part.length); + nanobind::bytes pyBytes(part.data, part.length); accum->pyWriteFunction(pyBytes); } else { - pybind11::str pyStr(part.data, + nanobind::str pyStr(part.data, part.length); // Decodes as UTF-8 by default. accum->pyWriteFunction(pyStr); } @@ -146,7 +155,7 @@ class PyFileAccumulator { } private: - pybind11::object pyWriteFunction; + nanobind::object pyWriteFunction; bool binary; }; @@ -163,17 +172,17 @@ struct PySinglePartStringAccumulator { assert(!accum->invoked && "PySinglePartStringAccumulator called back multiple times"); accum->invoked = true; - accum->value = pybind11::str(part.data, part.length); + accum->value = nanobind::str(part.data, part.length); }; } - pybind11::str takeValue() { + nanobind::str takeValue() { assert(invoked && "PySinglePartStringAccumulator not called back"); return std::move(value); } private: - pybind11::str value; + nanobind::str value; bool invoked = false; }; @@ -208,7 +217,7 @@ struct PySinglePartStringAccumulator { template class Sliceable { protected: - using ClassTy = pybind11::class_; + using ClassTy = nanobind::class_; /// Transforms `index` into a legal value to access the underlying sequence. /// Returns <0 on failure. @@ -237,7 +246,7 @@ class Sliceable { /// Returns the element at the given slice index. Supports negative indices /// by taking elements in inverse order. Returns a nullptr object if out /// of bounds. - pybind11::object getItem(intptr_t index) { + nanobind::object getItem(intptr_t index) { // Negative indices mean we count from the end. index = wrapIndex(index); if (index < 0) { @@ -250,20 +259,20 @@ class Sliceable { ->getRawElement(linearizeIndex(index)) .maybeDownCast(); else - return pybind11::cast( + return nanobind::cast( static_cast(this)->getRawElement(linearizeIndex(index))); } /// Returns a new instance of the pseudo-container restricted to the given /// slice. Returns a nullptr object on failure. - pybind11::object getItemSlice(PyObject *slice) { + nanobind::object getItemSlice(PyObject *slice) { ssize_t start, stop, extraStep, sliceLength; if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep, &sliceLength) != 0) { PyErr_SetString(PyExc_IndexError, "index out of range"); return {}; } - return pybind11::cast(static_cast(this)->slice( + return nanobind::cast(static_cast(this)->slice( startIndex + start * step, sliceLength, step * extraStep)); } @@ -279,7 +288,7 @@ class Sliceable { // Negative indices mean we count from the end. index = wrapIndex(index); if (index < 0) { - throw pybind11::index_error("index out of range"); + throw nanobind::index_error("index out of range"); } return static_cast(this)->getRawElement(linearizeIndex(index)); @@ -304,39 +313,38 @@ class Sliceable { } /// Binds the indexing and length methods in the Python class. - static void bind(pybind11::module &m) { - auto clazz = pybind11::class_(m, Derived::pyClassName, - pybind11::module_local()) + static void bind(nanobind::module_ &m) { + auto clazz = nanobind::class_(m, Derived::pyClassName) .def("__add__", &Sliceable::dunderAdd); Derived::bindDerived(clazz); // Manually implement the sequence protocol via the C API. We do this - // because it is approx 4x faster than via pybind11, largely because that + // because it is approx 4x faster than via nanobind, largely because that // formulation requires a C++ exception to be thrown to detect end of // sequence. // Since we are in a C-context, any C++ exception that happens here // will terminate the program. There is nothing in this implementation // that should throw in a non-terminal way, so we forgo further // exception marshalling. - // See: https://github.com/pybind/pybind11/issues/2842 + // See: https://github.com/pybind/nanobind/issues/2842 auto heap_type = reinterpret_cast(clazz.ptr()); assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE && "must be heap type"); heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t { - auto self = pybind11::cast(rawSelf); + auto self = nanobind::cast(nanobind::handle(rawSelf)); return self->length; }; // sq_item is called as part of the sequence protocol for iteration, // list construction, etc. heap_type->as_sequence.sq_item = +[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * { - auto self = pybind11::cast(rawSelf); + auto self = nanobind::cast(nanobind::handle(rawSelf)); return self->getItem(index).release().ptr(); }; // mp_subscript is used for both slices and integer lookups. heap_type->as_mapping.mp_subscript = +[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * { - auto self = pybind11::cast(rawSelf); + auto self = nanobind::cast(nanobind::handle(rawSelf)); Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError); if (!PyErr_Occurred()) { // Integer indexing. diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index e991deaae2daa5..b5dce4fe4128a5 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -8,12 +8,16 @@ #include "Pass.h" +#include +#include +#include + #include "IRModule.h" #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/Pass.h" -namespace py = pybind11; -using namespace py::literals; +namespace nb = nanobind; +using namespace nb::literals; using namespace mlir; using namespace mlir::python; @@ -34,16 +38,15 @@ class PyPassManager { MlirPassManager get() { return passManager; } void release() { passManager.ptr = nullptr; } - pybind11::object getCapsule() { - return py::reinterpret_steal( - mlirPythonPassManagerToCapsule(get())); + nb::object getCapsule() { + return nb::steal(mlirPythonPassManagerToCapsule(get())); } - static pybind11::object createFromCapsule(pybind11::object capsule) { + static nb::object createFromCapsule(nb::object capsule) { MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr()); if (mlirPassManagerIsNull(rawPm)) - throw py::error_already_set(); - return py::cast(PyPassManager(rawPm), py::return_value_policy::move); + throw nb::python_error(); + return nb::cast(PyPassManager(rawPm), nb::rv_policy::move); } private: @@ -53,22 +56,23 @@ class PyPassManager { } // namespace /// Create the `mlir.passmanager` here. -void mlir::python::populatePassManagerSubmodule(py::module &m) { +void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { //---------------------------------------------------------------------------- // Mapping of the top-level PassManager //---------------------------------------------------------------------------- - py::class_(m, "PassManager", py::module_local()) - .def(py::init<>([](const std::string &anchorOp, - DefaultingPyMlirContext context) { - MlirPassManager passManager = mlirPassManagerCreateOnOperation( - context->get(), - mlirStringRefCreate(anchorOp.data(), anchorOp.size())); - return new PyPassManager(passManager); - }), - "anchor_op"_a = py::str("any"), "context"_a = py::none(), - "Create a new PassManager for the current (or provided) Context.") - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyPassManager::getCapsule) + nb::class_(m, "PassManager") + .def( + "__init__", + [](PyPassManager &self, const std::string &anchorOp, + DefaultingPyMlirContext context) { + MlirPassManager passManager = mlirPassManagerCreateOnOperation( + context->get(), + mlirStringRefCreate(anchorOp.data(), anchorOp.size())); + new (&self) PyPassManager(passManager); + }, + "anchor_op"_a = nb::str("any"), "context"_a.none() = nb::none(), + "Create a new PassManager for the current (or provided) Context.") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyPassManager::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule) .def("_testing_release", &PyPassManager::release, "Releases (leaks) the backing pass manager (testing)") @@ -101,9 +105,9 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { "print_before_all"_a = false, "print_after_all"_a = true, "print_module_scope"_a = false, "print_after_change"_a = false, "print_after_failure"_a = false, - "large_elements_limit"_a = py::none(), "enable_debug_info"_a = false, - "print_generic_op_form"_a = false, - "tree_printing_dir_path"_a = py::none(), + "large_elements_limit"_a.none() = nb::none(), + "enable_debug_info"_a = false, "print_generic_op_form"_a = false, + "tree_printing_dir_path"_a.none() = nb::none(), "Enable IR printing, default as mlir-print-ir-after-all.") .def( "enable_verifier", @@ -121,10 +125,10 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { mlirStringRefCreate(pipeline.data(), pipeline.size()), errorMsg.getCallback(), errorMsg.getUserData()); if (mlirLogicalResultIsFailure(status)) - throw py::value_error(std::string(errorMsg.join())); + throw nb::value_error(errorMsg.join().c_str()); return new PyPassManager(passManager); }, - "pipeline"_a, "context"_a = py::none(), + "pipeline"_a, "context"_a.none() = nb::none(), "Parse a textual pass-pipeline and return a top-level PassManager " "that can be applied on a Module. Throw a ValueError if the pipeline " "can't be parsed") @@ -137,7 +141,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { mlirStringRefCreate(pipeline.data(), pipeline.size()), errorMsg.getCallback(), errorMsg.getUserData()); if (mlirLogicalResultIsFailure(status)) - throw py::value_error(std::string(errorMsg.join())); + throw nb::value_error(errorMsg.join().c_str()); }, "pipeline"_a, "Add textual pipeline elements to the pass manager. Throws a " diff --git a/mlir/lib/Bindings/Python/Pass.h b/mlir/lib/Bindings/Python/Pass.h index 3a500d5e8257ac..bc409435218299 100644 --- a/mlir/lib/Bindings/Python/Pass.h +++ b/mlir/lib/Bindings/Python/Pass.h @@ -9,12 +9,12 @@ #ifndef MLIR_BINDINGS_PYTHON_PASS_H #define MLIR_BINDINGS_PYTHON_PASS_H -#include "PybindUtils.h" +#include "NanobindUtils.h" namespace mlir { namespace python { -void populatePassManagerSubmodule(pybind11::module &m); +void populatePassManagerSubmodule(nanobind::module_ &m); } // namespace python } // namespace mlir diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 1d8128be9f0826..b2c1de4be9a69c 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -8,14 +8,16 @@ #include "Rewrite.h" +#include + #include "IRModule.h" #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/Rewrite.h" #include "mlir/Config/mlir-config.h" -namespace py = pybind11; +namespace nb = nanobind; using namespace mlir; -using namespace py::literals; +using namespace nb::literals; using namespace mlir::python; namespace { @@ -54,18 +56,17 @@ class PyFrozenRewritePatternSet { } MlirFrozenRewritePatternSet get() { return set; } - pybind11::object getCapsule() { - return py::reinterpret_steal( + nb::object getCapsule() { + return nb::steal( mlirPythonFrozenRewritePatternSetToCapsule(get())); } - static pybind11::object createFromCapsule(pybind11::object capsule) { + static nb::object createFromCapsule(nb::object capsule) { MlirFrozenRewritePatternSet rawPm = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr()); if (rawPm.ptr == nullptr) - throw py::error_already_set(); - return py::cast(PyFrozenRewritePatternSet(rawPm), - py::return_value_policy::move); + throw nb::python_error(); + return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move); } private: @@ -75,25 +76,27 @@ class PyFrozenRewritePatternSet { } // namespace /// Create the `mlir.rewrite` here. -void mlir::python::populateRewriteSubmodule(py::module &m) { +void mlir::python::populateRewriteSubmodule(nb::module_ &m) { //---------------------------------------------------------------------------- // Mapping of the top-level PassManager //---------------------------------------------------------------------------- #if MLIR_ENABLE_PDL_IN_PATTERNMATCH - py::class_(m, "PDLModule", py::module_local()) - .def(py::init<>([](MlirModule module) { - return mlirPDLPatternModuleFromModule(module); - }), - "module"_a, "Create a PDL module from the given module.") + nb::class_(m, "PDLModule") + .def( + "__init__", + [](PyPDLPatternModule &self, MlirModule module) { + new (&self) + PyPDLPatternModule(mlirPDLPatternModuleFromModule(module)); + }, + "module"_a, "Create a PDL module from the given module.") .def("freeze", [](PyPDLPatternModule &self) { return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern( mlirRewritePatternSetFromPDLPatternModule(self.get()))); }); -#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCg - py::class_(m, "FrozenRewritePatternSet", - py::module_local()) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyFrozenRewritePatternSet::getCapsule) +#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH + nb::class_(m, "FrozenRewritePatternSet") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, + &PyFrozenRewritePatternSet::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyFrozenRewritePatternSet::createFromCapsule); m.def( @@ -102,7 +105,7 @@ void mlir::python::populateRewriteSubmodule(py::module &m) { auto status = mlirApplyPatternsAndFoldGreedily(module, set, {}); if (mlirLogicalResultIsFailure(status)) // FIXME: Not sure this is the right error to throw here. - throw py::value_error("pattern application failed to converge"); + throw nb::value_error("pattern application failed to converge"); }, "module"_a, "set"_a, "Applys the given patterns to the given module greedily while folding " diff --git a/mlir/lib/Bindings/Python/Rewrite.h b/mlir/lib/Bindings/Python/Rewrite.h index 997b80adda3038..ae89e2b9589f13 100644 --- a/mlir/lib/Bindings/Python/Rewrite.h +++ b/mlir/lib/Bindings/Python/Rewrite.h @@ -9,12 +9,12 @@ #ifndef MLIR_BINDINGS_PYTHON_REWRITE_H #define MLIR_BINDINGS_PYTHON_REWRITE_H -#include "PybindUtils.h" +#include "NanobindUtils.h" namespace mlir { namespace python { -void populateRewriteSubmodule(pybind11::module &m); +void populateRewriteSubmodule(nanobind::module_ &m); } // namespace python } // namespace mlir diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index e1b870b53ad25c..d3ca940b408276 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -440,6 +440,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core MODULE_NAME _mlir ADD_TO_PARENT MLIRPythonSources.Core ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES MainModule.cpp IRAffine.cpp @@ -455,7 +456,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core Globals.h IRModule.h Pass.h - PybindUtils.h + NanobindUtils.h Rewrite.h PRIVATE_LINK_LIBS LLVMSupport diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index ab8a9122919e19..f240d6ef944ec7 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -1,4 +1,4 @@ -nanobind>=2.0, <3.0 +nanobind>=2.4, <3.0 numpy>=1.19.5, <=2.1.2 pybind11>=2.10.0, <=2.13.6 PyYAML>=5.4.0, <=6.0.1 diff --git a/mlir/test/python/ir/symbol_table.py b/mlir/test/python/ir/symbol_table.py index 577721ab2111f5..8b6d7ea5a197d7 100644 --- a/mlir/test/python/ir/symbol_table.py +++ b/mlir/test/python/ir/symbol_table.py @@ -176,5 +176,6 @@ def error_callback(symbol_table_op, uses_visible): try: SymbolTable.walk_symbol_tables(m.operation, True, error_callback) except RuntimeError as e: - # CHECK: GOT EXCEPTION: Exception raised in callback: AssertionError: Raised from python + # CHECK: GOT EXCEPTION: Exception raised in callback: + # CHECK: AssertionError: Raised from python print(f"GOT EXCEPTION: {e}") diff --git a/utils/bazel/WORKSPACE b/utils/bazel/WORKSPACE index 66ba1ac1b17e1e..005a4b9d7b5ad2 100644 --- a/utils/bazel/WORKSPACE +++ b/utils/bazel/WORKSPACE @@ -161,9 +161,9 @@ maybe( http_archive, name = "nanobind", build_file = "@llvm-raw//utils/bazel/third_party_build:nanobind.BUILD", - sha256 = "bfbfc7e5759f1669e4ddb48752b1ddc5647d1430e94614d6f8626df1d508e65a", - strip_prefix = "nanobind-2.2.0", - url = "https://github.com/wjakob/nanobind/archive/refs/tags/v2.2.0.tar.gz", + sha256 = "bb35deaed7efac5029ed1e33880a415638352f757d49207a8e6013fefb6c49a7", + strip_prefix = "nanobind-2.4.0", + url = "https://github.com/wjakob/nanobind/archive/refs/tags/v2.4.0.tar.gz", ) load("@rules_python//python:repositories.bzl", "py_repositories", "python_register_toolchains") diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 544becfa30b40f..aee8aab8498ce2 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -1026,6 +1026,9 @@ cc_library( srcs = [":MLIRBindingsPythonSourceFiles"], copts = PYBIND11_COPTS, features = PYBIND11_FEATURES, + includes = [ + "lib/Bindings/Python", + ], textual_hdrs = [":MLIRBindingsPythonCoreHeaders"], deps = [ ":CAPIAsync", @@ -1033,11 +1036,11 @@ cc_library( ":CAPIIR", ":CAPIInterfaces", ":CAPITransforms", - ":MLIRBindingsPythonHeadersAndDeps", + ":MLIRBindingsPythonNanobindHeadersAndDeps", ":Support", ":config", "//llvm:Support", - "@pybind11", + "@nanobind", "@rules_python//python/cc:current_py_cc_headers", ], ) @@ -1047,17 +1050,20 @@ cc_library( srcs = [":MLIRBindingsPythonSourceFiles"], copts = PYBIND11_COPTS, features = PYBIND11_FEATURES, + includes = [ + "lib/Bindings/Python", + ], textual_hdrs = [":MLIRBindingsPythonCoreHeaders"], deps = [ ":CAPIAsyncHeaders", ":CAPIDebugHeaders", ":CAPIIRHeaders", ":CAPITransformsHeaders", - ":MLIRBindingsPythonHeaders", + ":MLIRBindingsPythonNanobindHeaders", ":Support", ":config", "//llvm:Support", - "@pybind11", + "@nanobind", "@rules_python//python/cc:current_py_cc_headers", ], ) @@ -1090,6 +1096,7 @@ cc_binary( deps = [ ":MLIRBindingsPythonCore", ":MLIRBindingsPythonHeadersAndDeps", + "@nanobind", ], )