llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-mlir Author: Oleksandr "Alex" Zinenko (ftynse) <details> <summary>Changes</summary> Port the bindings for non-shaped builtin types in IRTypes.cpp to use the `mlir_type_subclass` mechanism used by non-builtin types. This is part of a longer-term cleanup to only support one subclassing mechanism. Eventually, the `PyConcreteType` mechanism will be removed. This required surgery in the type casters and the `mlir_type_subclass` logic to avoid circular imports of the `_mlir.ir` module that would otherwise occur when using `mlir_type_subclass` to define classes in the `_mlir.ir` module. Tests are updated to use the `.get_static_typeid()` function instead of the `.static_typeid` property that was specific to builtin types due to the `PyConcreteType` mechanism. The change should be NFC otherwise. --- Patch is 49.00 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/171143.diff 6 Files Affected: - (modified) mlir/include/mlir/Bindings/Python/NanobindAdaptors.h (+30-11) - (modified) mlir/lib/Bindings/Python/IRTypes.cpp (+366-663) - (modified) mlir/lib/Bindings/Python/MainModule.cpp (+15) - (modified) mlir/test/python/dialects/arith_dialect.py (+4-4) - (modified) mlir/test/python/ir/builtin_types.py (+7-4) - (modified) mlir/test/python/ir/value.py (+3-3) ``````````diff diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h index 6594670abaaa7..f678f57527e97 100644 --- a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h @@ -371,16 +371,22 @@ struct type_caster<MlirTypeID> { } return false; } - static handle from_cpp(MlirTypeID v, rv_policy, - cleanup_list *cleanup) noexcept { + + static handle + from_cpp_given_module(MlirTypeID v, + const nanobind::module_ &module) noexcept { if (v.ptr == nullptr) return nanobind::none(); nanobind::object capsule = 
nanobind::steal<nanobind::object>(mlirPythonTypeIDToCapsule(v)); - return mlir::python::irModule() - .attr("TypeID") + return module.attr("TypeID") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) .release(); + } + + static handle from_cpp(MlirTypeID v, rv_policy, + cleanup_list *cleanup) noexcept { + return from_cpp_given_module(v, mlir::python::irModule()); }; }; @@ -602,9 +608,12 @@ class mlir_type_subclass : public pure_subclass { /// Subclasses by looking up the super-class dynamically. mlir_type_subclass(nanobind::handle scope, const char *typeClassName, IsAFunctionTy isaFunction, - GetTypeIDFunctionTy getTypeIDFunction = nullptr) - : mlir_type_subclass(scope, typeClassName, isaFunction, - irModule().attr("Type"), getTypeIDFunction) {} + GetTypeIDFunctionTy getTypeIDFunction = nullptr, + const nanobind::module_ *mlirIrModule = nullptr) + : mlir_type_subclass( + scope, typeClassName, isaFunction, + (mlirIrModule != nullptr ? *mlirIrModule : irModule()).attr("Type"), + getTypeIDFunction, mlirIrModule) {} /// Subclasses with a provided mlir.ir.Type super-class. This must /// be used if the subclass is being defined in the same extension module @@ -613,7 +622,8 @@ class mlir_type_subclass : public pure_subclass { mlir_type_subclass(nanobind::handle scope, const char *typeClassName, IsAFunctionTy isaFunction, const nanobind::object &superCls, - GetTypeIDFunctionTy getTypeIDFunction = nullptr) + GetTypeIDFunctionTy getTypeIDFunction = nullptr, + const nanobind::module_ *mlirIrModule = nullptr) : pure_subclass(scope, typeClassName, superCls) { // Casting constructor. 
Note that it is hard, if not impossible, to properly // call chain to parent `__init__` in nanobind due to its special handling @@ -672,9 +682,18 @@ class mlir_type_subclass : public pure_subclass { nanobind::sig("def get_static_typeid() -> " MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID")) // clang-format on ); - nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( - getTypeIDFunction())(nanobind::cpp_function( + + // Directly call the caster implementation given the "ir" module, + // otherwise it may trigger recursive import as the default caster + // attempts to import the "ir" module. + MlirTypeID typeID = getTypeIDFunction(); + mlirIrModule = mlirIrModule ? mlirIrModule : &irModule(); + nanobind::handle pyTypeID = + nanobind::detail::type_caster<MlirTypeID>::from_cpp_given_module( + typeID, *mlirIrModule); + + mlirIrModule->attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(pyTypeID)( + nanobind::cpp_function( [thisClass = thisClass](const nanobind::object &mlirType) { return thisClass(mlirType); })); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 34c5b8dd86a66..2e4090c358c47 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -18,13 +18,13 @@ #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/Support.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" namespace nb = nanobind; using namespace mlir; using namespace mlir::python; using llvm::SmallVector; -using llvm::Twine; namespace { @@ -34,480 +34,368 @@ static int mlirTypeIsAIntegerOrFloat(MlirType type) { mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type); } -class PyIntegerType : public PyConcreteType<PyIntegerType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirIntegerTypeGetTypeID; - static constexpr const char 
*pyClassName = "IntegerType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get_signless", - [](unsigned width, DefaultingPyMlirContext context) { - MlirType t = mlirIntegerTypeGet(context->get(), width); - return PyIntegerType(context->getRef(), t); - }, - nb::arg("width"), nb::arg("context") = nb::none(), - "Create a signless integer type"); - c.def_static( - "get_signed", - [](unsigned width, DefaultingPyMlirContext context) { - MlirType t = mlirIntegerTypeSignedGet(context->get(), width); - return PyIntegerType(context->getRef(), t); - }, - nb::arg("width"), nb::arg("context") = nb::none(), - "Create a signed integer type"); - c.def_static( - "get_unsigned", - [](unsigned width, DefaultingPyMlirContext context) { - MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width); - return PyIntegerType(context->getRef(), t); - }, - nb::arg("width"), nb::arg("context") = nb::none(), - "Create an unsigned integer type"); - c.def_prop_ro( - "width", - [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); }, - "Returns the width of the integer type"); - c.def_prop_ro( - "is_signless", - [](PyIntegerType &self) -> bool { - return mlirIntegerTypeIsSignless(self); - }, - "Returns whether this is a signless integer"); - c.def_prop_ro( - "is_signed", - [](PyIntegerType &self) -> bool { - return mlirIntegerTypeIsSigned(self); - }, - "Returns whether this is a signed integer"); - c.def_prop_ro( - "is_unsigned", - [](PyIntegerType &self) -> bool { - return mlirIntegerTypeIsUnsigned(self); - }, - "Returns whether this is an unsigned integer"); - } -}; - -/// Index Type subclass - IndexType. 
-class PyIndexType : public PyConcreteType<PyIndexType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirIndexTypeGetTypeID; - static constexpr const char *pyClassName = "IndexType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirIndexTypeGet(context->get()); - return PyIndexType(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a index type."); - } -}; - -class PyFloatType : public PyConcreteType<PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat; - static constexpr const char *pyClassName = "FloatType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_prop_ro( - "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); }, - "Returns the width of the floating-point type"); - } -}; - -/// Floating Point Type subclass - Float4E2M1FNType. -class PyFloat4E2M1FNType - : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat4E2M1FNTypeGetTypeID; - static constexpr const char *pyClassName = "Float4E2M1FNType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat4E2M1FNTypeGet(context->get()); - return PyFloat4E2M1FNType(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float4_e2m1fn type."); - } -}; - -/// Floating Point Type subclass - Float6E2M3FNType. 
-class PyFloat6E2M3FNType - : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat6E2M3FNTypeGetTypeID; - static constexpr const char *pyClassName = "Float6E2M3FNType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat6E2M3FNTypeGet(context->get()); - return PyFloat6E2M3FNType(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float6_e2m3fn type."); - } -}; - -/// Floating Point Type subclass - Float6E3M2FNType. -class PyFloat6E3M2FNType - : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat6E3M2FNTypeGetTypeID; - static constexpr const char *pyClassName = "Float6E3M2FNType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat6E3M2FNTypeGet(context->get()); - return PyFloat6E3M2FNType(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float6_e3m2fn type."); - } -}; - -/// Floating Point Type subclass - Float8E4M3FNType. 
-class PyFloat8E4M3FNType - : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E4M3FNTypeGetTypeID; - static constexpr const char *pyClassName = "Float8E4M3FNType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E4M3FNTypeGet(context->get()); - return PyFloat8E4M3FNType(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float8_e4m3fn type."); - } -}; - -/// Floating Point Type subclass - Float8E5M2Type. -class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E5M2TypeGetTypeID; - static constexpr const char *pyClassName = "Float8E5M2Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E5M2TypeGet(context->get()); - return PyFloat8E5M2Type(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float8_e5m2 type."); - } -}; - -/// Floating Point Type subclass - Float8E4M3Type. 
-class PyFloat8E4M3Type : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E4M3TypeGetTypeID; - static constexpr const char *pyClassName = "Float8E4M3Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E4M3TypeGet(context->get()); - return PyFloat8E4M3Type(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float8_e4m3 type."); - } -}; - -/// Floating Point Type subclass - Float8E4M3FNUZ. -class PyFloat8E4M3FNUZType - : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E4M3FNUZTypeGetTypeID; - static constexpr const char *pyClassName = "Float8E4M3FNUZType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get()); - return PyFloat8E4M3FNUZType(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float8_e4m3fnuz type."); - } -}; - -/// Floating Point Type subclass - Float8E4M3B11FNUZ. 
-class PyFloat8E4M3B11FNUZType - : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E4M3B11FNUZTypeGetTypeID; - static constexpr const char *pyClassName = "Float8E4M3B11FNUZType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get()); - return PyFloat8E4M3B11FNUZType(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float8_e4m3b11fnuz type."); - } -}; - -/// Floating Point Type subclass - Float8E5M2FNUZ. -class PyFloat8E5M2FNUZType - : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E5M2FNUZTypeGetTypeID; - static constexpr const char *pyClassName = "Float8E5M2FNUZType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get()); - return PyFloat8E5M2FNUZType(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float8_e5m2fnuz type."); - } -}; - -/// Floating Point Type subclass - Float8E3M4Type. 
-class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E3M4TypeGetTypeID; - static constexpr const char *pyClassName = "Float8E3M4Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E3M4TypeGet(context->get()); - return PyFloat8E3M4Type(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float8_e3m4 type."); - } -}; - -/// Floating Point Type subclass - Float8E8M0FNUType. -class PyFloat8E8M0FNUType - : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E8M0FNUTypeGetTypeID; - static constexpr const char *pyClassName = "Float8E8M0FNUType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E8M0FNUTypeGet(context->get()); - return PyFloat8E8M0FNUType(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float8_e8m0fnu type."); - } -}; - -/// Floating Point Type subclass - BF16Type. 
-class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirBFloat16TypeGetTypeID; - static constexpr const char *pyClassName = "BF16Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirBF16TypeGet(context->get()); - return PyBF16Type(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a bf16 type."); - } -}; - -/// Floating Point Type subclass - F16Type. -class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat16TypeGetTypeID; - static constexpr const char *pyClassName = "F16Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirF16TypeGet(context->get()); - return PyF16Type(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a f16 type."); - } -}; - -/// Floating Point Type subclass - TF32Type. -class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloatTF32TypeGetTypeID; - static constexpr const char *pyClassName = "FloatTF32Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirTF32TypeGet(context->get()); - return PyTF32Type(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a tf32 type."); - } -}; - -/// Floating Point Type subclass - F32Type. 
-class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat32TypeGetTypeID; - static constexpr const char *pyClassName = "F32Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirF32TypeGet(context->get()); - return PyF32Type(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a f32 type."); - } -}; +static void populateIRTypesModule(const nanobind::module_ &m) { + using namespace nanobind_adaptors; -/// Floating Point Type subclass - F64Type. -class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat64TypeGetTypeID; - static constexpr const char *pyClassName = "F64Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirF64TypeGet(context->get()); - return PyF64Type(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a f64 type."); - } -}; - -/// None Type subclass - NoneType. -class PyNoneType : public PyConcreteType<PyNoneType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirNoneTypeGetTypeID; - static constexpr const char *pyClassName = "NoneType"; - using PyConcreteType::PyConcreteType; - - ... [truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/171143 _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
