This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 49a5d71a feat(python)!: enforce __slots__=() for Object subclasses via _ObjectSlotsMeta (#480)
49a5d71a is described below
commit 49a5d71a3145aee20b6cfbcb7a2f7d9feb25f2f7
Author: Junru Shao <[email protected]>
AuthorDate: Fri Feb 27 12:34:54 2026 -0800
feat(python)!: enforce __slots__=() for Object subclasses via _ObjectSlotsMeta (#480)
## Summary
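- Split the Cython `Object` into `CObject` (the extension type that owns
the handle) and `Object` (a Python class using the `_ObjectSlotsMeta`
metaclass) so that `__slots__ = ()` is enforced across the Object
hierarchy. `_ObjectSlotsMeta` overrides `__instancecheck__` and
`__subclasscheck__` so that `CObject`-derived extension types (e.g.
`Function`, `Tensor`, `Error`) still pass `isinstance(x, Object)`.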
- Remove the `slots: bool = True` keyword argument from
`_ObjectSlotsMeta`; subclasses that need a `__dict__` now declare
`__slots__ = ("__dict__",)` explicitly.
- Simplify `Module.__getattr__` to cache looked-up functions with
`setattr`, which stores them in the instance `__dict__` (enabled by
`__slots__ = ("__dict__",)` on `Module`), so later accesses are served by
normal attribute lookup without reaching `__getattr__`.
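- Rename the private `__object_repr__` helper to `object_repr()`, replace
its module-global `_REPR_PRINT` cache with a local `_ffi_api` import, and
move the null-handle check into it; the now-redundant `__repr__` overrides
on `Array`, `List`, `Map`, and `Dict` are removed.
- Treat `ffi.Shape` as a container when setting up copy methods, and call
`ffi.DeepCopy` through `_ffi_api` instead of an `lru_cache`d global lookup.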
## Breaking Changes
- All `Object` subclasses now get `__slots__ = ()` by default, so code
that sets arbitrary instance attributes raises `AttributeError`. Migrate
by declaring the needed slots explicitly (e.g., `__slots__ = ("__dict__",)`).
- `_ObjectSlotsMeta` no longer accepts the `slots` keyword argument.
- `core.__object_repr__` is replaced by `core.object_repr`.
- `DLDeviceType.kDLTrn` is renumbered from 17 to 18 to make room for the
new `DLDeviceType.kDLMAIA = 17` (device name `"maia"`).
## Test plan
- [x] `uv run pytest -vvs tests/python` — 532 passed, 23 skipped, 1 xfailed
- [x] All pre-commit hooks pass (ruff, ty, cython-lint, etc.)
- [x] C++ tests (no C++ changes in this PR)
- [x] Rust tests (no Rust changes in this PR)
---
python/tvm_ffi/access_path.py | 1 -
python/tvm_ffi/container.py | 26 ----
python/tvm_ffi/core.pyi | 8 +-
python/tvm_ffi/cython/base.pxi | 1 +
python/tvm_ffi/cython/device.pxi | 6 +-
python/tvm_ffi/cython/dtype.pxi | 1 +
python/tvm_ffi/cython/error.pxi | 7 +-
python/tvm_ffi/cython/function.pxi | 48 +++----
python/tvm_ffi/cython/object.pxi | 247 ++++++++++++++++++++++--------------
python/tvm_ffi/cython/string.pxi | 4 +-
python/tvm_ffi/cython/tensor.pxi | 12 +-
python/tvm_ffi/cython/type_info.pxi | 8 +-
python/tvm_ffi/module.py | 9 +-
python/tvm_ffi/registry.py | 16 ++-
tests/python/test_object.py | 43 +++++++
15 files changed, 260 insertions(+), 177 deletions(-)
diff --git a/python/tvm_ffi/access_path.py b/python/tvm_ffi/access_path.py
index 243a5d4c..258d693b 100644
--- a/python/tvm_ffi/access_path.py
+++ b/python/tvm_ffi/access_path.py
@@ -108,7 +108,6 @@ class AccessPath(Object):
def __init__(self) -> None:
"""Disallow direct construction; use `AccessPath.root()` instead."""
- super().__init__()
raise ValueError(
"AccessPath can't be initialized directly. "
"Use AccessPath.root() to create a path to the root object"
diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py
index d53366ae..de201555 100644
--- a/python/tvm_ffi/container.py
+++ b/python/tvm_ffi/container.py
@@ -194,13 +194,6 @@ class Array(core.Object, Sequence[T]):
for i in range(length):
yield self[i]
- def __repr__(self) -> str:
- """Return a string representation of the array."""
- # exception safety handling for chandle=None
- if self.__chandle__() == 0:
- return type(self).__name__ + "(chandle=None)"
- return str(core.__object_repr__(self)) # ty:
ignore[unresolved-attribute]
-
def __contains__(self, value: object) -> bool:
"""Check if the array contains a value."""
return _ffi_api.ArrayContains(self, value)
@@ -341,12 +334,6 @@ class List(core.Object, MutableSequence[T]):
for i in range(length):
yield cast(T, _ffi_api.ListGetItem(self, i))
- def __repr__(self) -> str:
- """Return a string representation of the list."""
- if self.__chandle__() == 0:
- return type(self).__name__ + "(chandle=None)"
- return str(core.__object_repr__(self)) # ty:
ignore[unresolved-attribute]
-
def __contains__(self, value: object) -> bool:
"""Check if the list contains a value."""
return _ffi_api.ListContains(self, value)
@@ -555,13 +542,6 @@ class Map(core.Object, Mapping[K, V]):
return default
return ret
- def __repr__(self) -> str:
- """Return a string representation of the map."""
- # exception safety handling for chandle=None
- if self.__chandle__() == 0:
- return type(self).__name__ + "(chandle=None)"
- return str(core.__object_repr__(self)) # ty:
ignore[unresolved-attribute]
-
@register_object("ffi.Dict")
class Dict(core.Object, MutableMapping[K, V]):
@@ -672,9 +652,3 @@ class Dict(core.Object, MutableMapping[K, V]):
"""Update the dict from a mapping."""
for k, v in other.items():
self[k] = v
-
- def __repr__(self) -> str:
- """Return a string representation of the dict."""
- if self.__chandle__() == 0:
- return type(self).__name__ + "(chandle=None)"
- return str(core.__object_repr__(self)) # ty:
ignore[unresolved-attribute]
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index 3ad1fff6..dd125c38 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -33,7 +33,7 @@ _TRACEBACK_TO_BACKTRACE_STR: Callable[[types.TracebackType |
None], str] | None
# DLPack protocol version (defined in tensor.pxi)
__dlpack_version__: tuple[int, int]
-class Object:
+class CObject:
def __ctypes_handle__(self) -> Any: ...
def __chandle__(self) -> int: ...
def __reduce__(self) -> Any: ...
@@ -47,7 +47,11 @@ class Object:
def __ffi_init__(self, *args: Any) -> None: ...
def same_as(self, other: Any) -> bool: ...
def _move(self) -> ObjectRValueRef: ...
- def __move_handle_from__(self, other: Object) -> None: ...
+ def __move_handle_from__(self, other: CObject) -> None: ...
+
+class Object(CObject): ...
+
+def object_repr(obj: CObject) -> str: ...
class ObjectConvertible:
def asobject(self) -> Object: ...
diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index 2b201eee..2c61babf 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -396,6 +396,7 @@ cdef extern from "tvm_ffi_python_helpers.h":
cdef class ByteArrayArg:
+ __slots__ = ()
cdef TVMFFIByteArray cdata
cdef object py_data
diff --git a/python/tvm_ffi/cython/device.pxi b/python/tvm_ffi/cython/device.pxi
index 2eb36fc9..f7ae3340 100644
--- a/python/tvm_ffi/cython/device.pxi
+++ b/python/tvm_ffi/cython/device.pxi
@@ -61,7 +61,8 @@ class DLDeviceType(IntEnum):
kDLOneAPI = 14
kDLWebGPU = 15
kDLHexagon = 16
- kDLTrn = 17
+ kDLMAIA = 17
+ kDLTrn = 18
cdef class Device:
@@ -91,6 +92,7 @@ cdef class Device:
assert str(dev) == "cuda:0"
"""
+ __slots__ = ()
cdef DLDevice cdevice
_DEVICE_TYPE_TO_NAME = {
@@ -108,6 +110,7 @@ cdef class Device:
DLDeviceType.kDLOneAPI: "oneapi",
DLDeviceType.kDLWebGPU: "webgpu",
DLDeviceType.kDLHexagon: "hexagon",
+ DLDeviceType.kDLMAIA: "maia",
DLDeviceType.kDLTrn: "trn",
}
@@ -127,6 +130,7 @@ cdef class Device:
"ext_dev": DLDeviceType.kDLExtDev,
"hexagon": DLDeviceType.kDLHexagon,
"webgpu": DLDeviceType.kDLWebGPU,
+ "maia": DLDeviceType.kDLMAIA,
"trn": DLDeviceType.kDLTrn,
}
diff --git a/python/tvm_ffi/cython/dtype.pxi b/python/tvm_ffi/cython/dtype.pxi
index c320c635..6838be6b 100644
--- a/python/tvm_ffi/cython/dtype.pxi
+++ b/python/tvm_ffi/cython/dtype.pxi
@@ -79,6 +79,7 @@ cdef class DataType:
assert str(d) == "int32"
"""
+ __slots__ = ()
cdef DLDataType cdtype
def __init__(self, dtype_str: str) -> None:
diff --git a/python/tvm_ffi/cython/error.pxi b/python/tvm_ffi/cython/error.pxi
index 5c560184..6f8159cd 100644
--- a/python/tvm_ffi/cython/error.pxi
+++ b/python/tvm_ffi/cython/error.pxi
@@ -27,7 +27,7 @@ _WITH_APPEND_BACKTRACE: Optional[Callable[[BaseException,
str], BaseException]]
_TRACEBACK_TO_BACKTRACE_STR: Optional[Callable[[types.TracebackType | None],
str]] = None
-cdef class Error(Object):
+cdef class Error(CObject):
"""Base class for FFI errors.
An :class:`Error` is a lightweight wrapper around a concrete Python
@@ -43,6 +43,7 @@ cdef class Error(Object):
Do not directly raise this object. Instead, use :py:meth:`py_error`
to convert it to a Python exception and raise that.
"""
+ __slots__ = ()
def __init__(self, kind: str, message: str, backtrace: str):
"""Construct an error wrapper.
@@ -66,7 +67,7 @@ cdef class Error(Object):
)
if ret != 0:
raise MemoryError("Failed to create error object")
- (<Object>self).chandle = out
+ (<CObject>self).chandle = out
def update_backtrace(self, backtrace: str) -> None:
"""Replace the stored backtrace string with ``backtrace``.
@@ -107,7 +108,7 @@ cdef class Error(Object):
cdef inline Error move_from_last_error():
# raise last error
error = Error.__new__(Error)
- TVMFFIErrorMoveFromRaised(&(<Object>error).chandle)
+ TVMFFIErrorMoveFromRaised(&(<CObject>error).chandle)
return error
diff --git a/python/tvm_ffi/cython/function.pxi
b/python/tvm_ffi/cython/function.pxi
index ec1542d2..f66d7dfb 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -131,7 +131,7 @@ cdef int TVMFFIPyArgSetterTensor_(
TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
PyObject* arg, TVMFFIAny* out
) except -1:
- if (<Object>arg).chandle != NULL:
+ if (<CObject>arg).chandle != NULL:
out.type_index = kTVMFFITensor
out.v_ptr = (<Tensor>arg).chandle
else:
@@ -144,8 +144,8 @@ cdef int TVMFFIPyArgSetterObject_(
TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
PyObject* arg, TVMFFIAny* out
) except -1:
- out.type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
- out.v_ptr = (<Object>arg).chandle
+ out.type_index = TVMFFIObjectGetTypeIndex((<CObject>arg).chandle)
+ out.v_ptr = (<CObject>arg).chandle
return 0
@@ -312,7 +312,7 @@ cdef int TVMFFIPyArgSetterFFIObjectProtocol_(
"""Setter for objects that implement the `__tvm_ffi_object__` protocol."""
cdef object arg = <object>py_arg
cdef TVMFFIObjectHandle temp_chandle
- cdef Object obj = arg.__tvm_ffi_object__()
+ cdef CObject obj = arg.__tvm_ffi_object__()
cdef long ref_count = Py_REFCNT(obj)
temp_chandle = obj.chandle
out.type_index = TVMFFIObjectGetTypeIndex(temp_chandle)
@@ -418,8 +418,8 @@ cdef int TVMFFIPyArgSetterPyNativeObjectStr_(
# need to check if the arg is a large string returned from ffi
if arg._tvm_ffi_cached_object is not None:
arg = arg._tvm_ffi_cached_object
- out.type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
- out.v_ptr = (<Object>arg).chandle
+ out.type_index = TVMFFIObjectGetTypeIndex((<CObject>arg).chandle)
+ out.v_ptr = (<CObject>arg).chandle
return 0
return TVMFFIPyArgSetterStr_(handle, ctx, py_arg, out)
@@ -457,8 +457,8 @@ cdef int TVMFFIPyArgSetterPyNativeObjectBytes_(
# need to check if the arg is a large bytes returned from ffi
if arg._tvm_ffi_cached_object is not None:
arg = arg._tvm_ffi_cached_object
- out.type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
- out.v_ptr = (<Object>arg).chandle
+ out.type_index = TVMFFIObjectGetTypeIndex((<CObject>arg).chandle)
+ out.v_ptr = (<CObject>arg).chandle
return 0
return TVMFFIPyArgSetterBytes_(handle, ctx, py_arg, out)
@@ -473,8 +473,8 @@ cdef int TVMFFIPyArgSetterPyNativeObjectGeneral_(
raise ValueError(f"_tvm_ffi_cached_object is None for {type(arg)}")
assert arg._tvm_ffi_cached_object is not None
arg = arg._tvm_ffi_cached_object
- out.type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
- out.v_ptr = (<Object>arg).chandle
+ out.type_index = TVMFFIObjectGetTypeIndex((<CObject>arg).chandle)
+ out.v_ptr = (<CObject>arg).chandle
return 0
@@ -507,7 +507,7 @@ cdef int TVMFFIPyArgSetterObjectRValueRef_(
"""Setter for ObjectRValueRef"""
cdef object arg = <object>py_arg
out.type_index = kTVMFFIObjectRValueRef
- out.v_ptr = &((<Object>(arg.obj)).chandle)
+ out.v_ptr = &((<CObject>(arg.obj)).chandle)
return 0
@@ -532,8 +532,8 @@ cdef int TVMFFIPyArgSetterException_(
"""Setter for Exception"""
cdef object arg = <object>py_arg
arg = _convert_to_ffi_error(arg)
- out.type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
- out.v_ptr = (<Object>arg).chandle
+ out.type_index = TVMFFIObjectGetTypeIndex((<CObject>arg).chandle)
+ out.v_ptr = (<CObject>arg).chandle
TVMFFIPyPushTempPyObject(ctx, <PyObject*>arg)
return 0
@@ -595,8 +595,8 @@ cdef int TVMFFIPyArgSetterObjectConvertible_(
# recursively construct a new map
cdef object arg = <object>py_arg
arg = arg.asobject()
- out.type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
- out.v_ptr = (<Object>arg).chandle
+ out.type_index = TVMFFIObjectGetTypeIndex((<CObject>arg).chandle)
+ out.v_ptr = (<CObject>arg).chandle
TVMFFIPyPushTempPyObject(ctx, <PyObject*>arg)
@@ -727,7 +727,7 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value,
TVMFFIPyArgSetter* out) exce
if isinstance(arg, Tensor):
out.func = TVMFFIPyArgSetterTensor_
return 0
- if isinstance(arg, Object):
+ if isinstance(arg, CObject):
out.func = TVMFFIPyArgSetterObject_
return 0
if isinstance(arg, ObjectRValueRef):
@@ -857,7 +857,7 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value,
TVMFFIPyArgSetter* out) exce
#
---------------------------------------------------------------------------------------------
# Implementation of function calling
#
---------------------------------------------------------------------------------------------
-cdef class Function(Object):
+cdef class Function(CObject):
"""Callable wrapper around a TVM FFI function.
Instances are obtained by converting Python callables with
@@ -908,7 +908,7 @@ cdef class Function(Object):
result.v_int64 = 0
TVMFFIPyFuncCall(
TVMFFIPyArgSetterFactory_,
- (<Object>self).chandle, <PyObject*>args,
+ (<CObject>self).chandle, <PyObject*>args,
&result,
&c_api_ret_code,
self.release_gil,
@@ -978,7 +978,7 @@ cdef class Function(Object):
CHECK_CALL(ret_code)
func = Function.__new__(Function)
- (<Object>func).chandle = chandle
+ (<CObject>func).chandle = chandle
return func
@staticmethod
@@ -1032,7 +1032,7 @@ cdef class Function(Object):
TVMFFIPyMLIRPackedSafeCallDeleter(mlir_packed_safe_call)
CHECK_CALL(ret_code)
func = Function.__new__(Function)
- (<Object>func).chandle = chandle
+ (<CObject>func).chandle = chandle
return func
@@ -1045,7 +1045,7 @@ def _register_global_func(name: str, pyfunc:
Callable[..., Any] | Function, over
if not isinstance(pyfunc, Function):
pyfunc = _convert_to_ffi_func(pyfunc)
- CHECK_CALL(TVMFFIFunctionSetGlobal(name_arg.cptr(),
(<Object>pyfunc).chandle, ioverride))
+ CHECK_CALL(TVMFFIFunctionSetGlobal(name_arg.cptr(),
(<CObject>pyfunc).chandle, ioverride))
return pyfunc
@@ -1056,7 +1056,7 @@ def _get_global_func(name: str, allow_missing: bool):
CHECK_CALL(TVMFFIFunctionGetGlobal(name_arg.cptr(), &chandle))
if chandle != NULL:
ret = Function.__new__(Function)
- (<Object>ret).chandle = chandle
+ (<CObject>ret).chandle = chandle
return ret
if allow_missing:
@@ -1111,7 +1111,7 @@ def _convert_to_ffi_func(object pyfunc: Callable[...,
Any]) -> Function:
cdef TVMFFIObjectHandle chandle
_convert_to_ffi_func_handle(pyfunc, &chandle)
ret = Function.__new__(Function)
- (<Object>ret).chandle = chandle
+ (<CObject>ret).chandle = chandle
return ret
@@ -1133,7 +1133,7 @@ def _convert_to_opaque_object(object pyobject: Any) ->
OpaquePyObject:
cdef TVMFFIObjectHandle chandle
_convert_to_opaque_object_handle(pyobject, &chandle)
ret = OpaquePyObject.__new__(OpaquePyObject)
- (<Object>ret).chandle = chandle
+ (<CObject>ret).chandle = chandle
return ret
diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi
index 15978387..834943c0 100644
--- a/python/tvm_ffi/cython/object.pxi
+++ b/python/tvm_ffi/cython/object.pxi
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import json
+from abc import ABCMeta
from typing import Any
@@ -26,23 +27,19 @@ def _set_class_object(cls):
_CLASS_OBJECT = cls
-_REPR_PRINT = None
-_REPR_PRINT_LOADED = False
+def object_repr(obj: "CObject") -> str:
+ """Return a human-readable repr of *obj* via ``ffi.ReprPrint``.
-
-def __object_repr__(obj: "Object") -> str:
- """Object repr function using ffi.ReprPrint when available."""
- global _REPR_PRINT, _REPR_PRINT_LOADED
- if not _REPR_PRINT_LOADED:
- _REPR_PRINT_LOADED = True
- _REPR_PRINT = _get_global_func("ffi.ReprPrint", False)
- if _REPR_PRINT is not None:
- try:
- return str(_REPR_PRINT(obj))
- except Exception: # noqa: BLE001
- # Silently fall back: __repr__ must never raise.
- pass
- return type(obj).__name__ + "(" + str(obj.__ctypes_handle__().value) + ")"
+ Falls back to ``TypeName(handle)`` if ``ReprPrint`` is unavailable.
+ """
+ if (<CObject>obj).chandle == NULL:
+ return type(obj).__name__ + "(chandle=None)"
+ try:
+ from tvm_ffi._ffi_api import ReprPrint
+ return str(ReprPrint(obj))
+ except Exception: # noqa: BLE001
+ # Silently fall back: repr must never raise.
+ return type(obj).__name__ + "(" + str(obj.__chandle__()) + ")"
def _new_object(cls):
@@ -91,38 +88,13 @@ class ObjectRValueRef:
self.obj = obj
-cdef class Object:
- """Base class of all TVM FFI objects.
-
- This is the root Python type for objects backed by the TVM FFI
- runtime. Each instance references a handle to a C++ runtime
- object. Python subclasses typically correspond to C++ runtime
- types and are registered via :py:meth:`tvm_ffi.register_object`.
-
- Notes
- -----
- - Equality of two :py:class:`Object` instances uses underlying handle
- identity unless an overridden implementation is provided on the
- concrete type. Use :py:meth:`same_as` to check whether two
- references point to the same underlying object.
- - Most users interact with subclasses (e.g. :class:`Tensor`,
- :class:`Function`) rather than :py:class:`Object` directly.
-
- Examples
- --------
- Constructing objects is typically performed by Python wrappers that
- call into registered constructors on the FFI side.
-
- .. code-block:: python
-
- import tvm_ffi.testing
-
- # Acquire a testing object constructed through FFI
- obj = tvm_ffi.testing.create_object("testing.TestObjectBase", v_i64=12)
- assert isinstance(obj, tvm_ffi.Object)
- assert obj.same_as(obj)
+cdef class CObject:
+ """Cython base class for TVM FFI objects.
+ This extension type owns the low-level handle. Prefer subclassing
+ :class:`Object` in Python to enforce slots policy.
"""
+ __slots__ = ()
cdef void* chandle
def __cinit__(self):
@@ -166,10 +138,7 @@ cdef class Object:
self.chandle = NULL
def __repr__(self) -> str:
- # exception safety handling for chandle=None
- if self.chandle == NULL:
- return type(self).__name__ + "(chandle=None)"
- return str(__object_repr__(self))
+ return object_repr(self)
def __eq__(self, other: object) -> bool:
return self.same_as(other)
@@ -177,30 +146,103 @@ cdef class Object:
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
- def __init_handle_by_constructor__(self, fconstructor: Any, *args: Any) ->
None:
- """Initialize the handle by calling constructor function.
+ def __hash__(self) -> int:
+ cdef uint64_t hash_value = <uint64_t>self.chandle
+ return hash_value
- Parameters
- ----------
- fconstructor : Function
- Constructor function.
+ def same_as(self, other: object) -> bool:
+ return isinstance(other, CObject) and self.chandle ==
(<CObject>other).chandle
- args: list of objects
- The arguments to the constructor
+ def __move_handle_from__(self, other: CObject) -> None:
+ self.chandle = (<CObject>other).chandle
+ (<CObject>other).chandle = NULL
- Notes
- -----
- We have a special calling convention to call constructor functions.
- So the return handle is directly set into the Node object
- instead of creating a new Node.
- """
+ def __init_handle_by_constructor__(self, fconstructor: Any, *args: Any) ->
None:
# avoid error raised during construction.
self.chandle = NULL
cdef void* chandle
ConstructorCall(
- (<Object>fconstructor).chandle, <PyObject*>args, &chandle, NULL)
+ (<CObject>fconstructor).chandle, <PyObject*>args, &chandle, NULL)
self.chandle = chandle
+
+class _ObjectSlotsMeta(ABCMeta):
+ def __new__(mcls, name: str, bases: tuple[type, ...], ns: dict[str, Any],
**kwargs: Any):
+ if "__slots__" not in ns:
+ ns["__slots__"] = ()
+ return super().__new__(mcls, name, bases, ns, **kwargs)
+
+ def __init__(cls, name: str, bases: tuple[type, ...], ns: dict[str, Any],
**kwargs: Any):
+ super().__init__(name, bases, ns, **kwargs)
+
+ def __instancecheck__(cls, instance: Any) -> bool:
+ if isinstance(instance, CObject):
+ return True
+ return super().__instancecheck__(instance)
+
+ def __subclasscheck__(cls, subclass: type) -> bool:
+ try:
+ if issubclass(subclass, CObject):
+ return True
+ except TypeError:
+ pass
+ return super().__subclasscheck__(subclass)
+
+
+class Object(CObject, metaclass=_ObjectSlotsMeta):
+ """Base class of all TVM FFI objects.
+
+ This is the root Python type for objects backed by the TVM FFI
+ runtime. Each instance references a handle to a C++ runtime
+ object. Python subclasses typically correspond to C++ runtime
+ types and are registered via :py:meth:`tvm_ffi.register_object`.
+
+ Notes
+ -----
+ - Equality of two :py:class:`Object` instances uses underlying handle
+ identity unless an overridden implementation is provided on the
+ concrete type. Use :py:meth:`same_as` to check whether two
+ references point to the same underlying object.
+ - Subclasses that omit ``__slots__`` are treated as ``__slots__ = ()``.
+ Subclasses that need per-instance dynamic attributes can opt in with
+ ``__slots__ = ("__dict__",)``.
+ - Most users interact with subclasses (e.g. :class:`Tensor`,
+ :class:`Function`) rather than :py:class:`Object` directly.
+
+ Examples
+ --------
+ Constructing objects is typically performed by Python wrappers that
+ call into registered constructors on the FFI side.
+
+ .. code-block:: python
+
+ import tvm_ffi.testing
+
+ # Acquire a testing object constructed through FFI
+ obj = tvm_ffi.testing.create_object("testing.TestObjectBase", v_i64=12)
+ assert isinstance(obj, tvm_ffi.Object)
+ assert obj.same_as(obj)
+
+ Subclasses can declare explicit slots when needed.
+
+ .. code-block:: python
+
+ @tvm_ffi.register_object("my.MyObject")
+ class MyObject(tvm_ffi.Object):
+ __slots__ = ()
+
+ Subclasses that need a per-instance ``__dict__`` (e.g. for attribute
+ caching) can opt in explicitly.
+
+ .. code-block:: python
+
+ @tvm_ffi.register_object("my.MyDynObject")
+ class MyDynObject(tvm_ffi.Object):
+ __slots__ = ("__dict__",)
+
+ """
+ __slots__ = ()
+
def __ffi_init__(self, *args: Any) -> None:
"""Initialize the instance using the ``__ffi_init__`` method
registered on C++ side.
@@ -239,13 +281,7 @@ cdef class Object:
assert not x.same_as(z)
"""
- if not isinstance(other, Object):
- return False
- return self.chandle == (<Object>other).chandle
-
- def __hash__(self) -> int:
- cdef uint64_t hash_value = <uint64_t>self.chandle
- return hash_value
+ return CObject.same_as(self, other)
def _move(self) -> ObjectRValueRef:
"""Create an rvalue reference that transfers ownership.
@@ -267,17 +303,35 @@ cdef class Object:
"""
return ObjectRValueRef(self)
- def __move_handle_from__(self, other: Object) -> None:
+ def __move_handle_from__(self, other: CObject) -> None:
"""Steal the FFI handle from ``other``.
Internal helper used by the runtime to implement move
semantics. Users should prefer :py:meth:`_move`.
"""
- self.chandle = (<Object>other).chandle
- (<Object>other).chandle = NULL
+ CObject.__move_handle_from__(self, other)
+ def __init_handle_by_constructor__(self, fconstructor: Any, *args: Any) ->
None:
+ """Initialize the handle by calling constructor function.
-cdef class OpaquePyObject(Object):
+ Parameters
+ ----------
+ fconstructor : Function
+ Constructor function.
+
+ args: list of objects
+ The arguments to the constructor
+
+ Notes
+ -----
+ We have a special calling convention to call constructor functions.
+ So the return handle is directly set into the Node object
+ instead of creating a new Node.
+ """
+ CObject.__init_handle_by_constructor__(self, fconstructor, *args)
+
+
+cdef class OpaquePyObject(CObject):
"""Wrapper that carries an arbitrary Python object across the FFI.
The contained object is held with correct reference counting, and
@@ -288,6 +342,8 @@ cdef class OpaquePyObject(Object):
``OpaquePyObject`` is useful when a Python value must traverse the
FFI boundary without conversion into a native FFI type.
"""
+ __slots__ = ()
+
def pyobject(self) -> object:
"""Return the original Python object held by this wrapper."""
cdef object obj
@@ -349,7 +405,7 @@ cdef inline str _type_index_to_key(int32_t tindex):
cdef inline object make_ret_opaque_object(TVMFFIAny result):
obj = OpaquePyObject.__new__(OpaquePyObject)
- (<Object>obj).chandle = result.v_obj
+ (<CObject>obj).chandle = result.v_obj
return obj.pyobject()
cdef inline object make_fallback_cls_for_type_index(int32_t type_index):
@@ -366,30 +422,25 @@ cdef inline object
make_fallback_cls_for_type_index(int32_t type_index):
# Create `type_info.type_cls` now
class cls(parent_type_info.type_cls):
- pass
- attrs = cls.__dict__.copy()
- attrs.pop("__dict__", None)
- attrs.pop("__weakref__", None)
- attrs.update({
- "__slots__": (),
- "__tvm_ffi_type_info__": type_info,
- "__name__": type_key.split(".")[-1],
- "__qualname__": type_key,
- "__module__": ".".join(type_key.split(".")[:-1]),
- "__doc__": f"Auto-generated fallback class for {type_key}.\n"
- "This class is generated because the class is not
registered.\n"
- "Please do not use this class directly, instead register
the class\n"
- "using `register_object` decorator.",
- })
+ __slots__ = ()
+
+ cls.__tvm_ffi_type_info__ = type_info
+ cls.__name__ = type_key.split(".")[-1]
+ cls.__qualname__ = type_key
+ cls.__module__ = ".".join(type_key.split(".")[:-1])
+ cls.__doc__ = (
+ f"Auto-generated fallback class for {type_key}.\n"
+ "This class is generated because the class is not registered.\n"
+ "Please do not use this class directly, instead register the class\n"
+ "using `register_object` decorator."
+ )
for field in type_info.fields:
- attrs[field.name] = field.as_property(cls)
+ setattr(cls, field.name, field.as_property(cls))
for method in type_info.methods:
name = method.name
if name == "__ffi_init__":
name = "__c_ffi_init__"
- attrs[name] = method.as_callable(cls)
- for name, val in attrs.items():
- setattr(cls, name, val)
+ setattr(cls, name, method.as_callable(cls))
# Update the registry
type_info.type_cls = cls
_update_registry(type_index, type_key, type_info, cls)
@@ -404,7 +455,7 @@ cdef inline object make_ret_object(TVMFFIAny result):
if type_index < len(TYPE_INDEX_TO_CLS) and (cls :=
TYPE_INDEX_TO_CLS[type_index]) is not None:
if issubclass(cls, PyNativeObject):
obj = Object.__new__(Object)
- (<Object>obj).chandle = result.v_obj
+ (<CObject>obj).chandle = result.v_obj
return cls.__from_tvm_ffi_object__(cls, obj)
else:
# Slow path: object is not found in registered entry
@@ -412,7 +463,7 @@ cdef inline object make_ret_object(TVMFFIAny result):
# For every unregistered class, this slow path will be triggered only
once.
cls = make_fallback_cls_for_type_index(type_index)
obj = cls.__new__(cls)
- (<Object>obj).chandle = result.v_obj
+ (<CObject>obj).chandle = result.v_obj
return obj
diff --git a/python/tvm_ffi/cython/string.pxi b/python/tvm_ffi/cython/string.pxi
index 2c23a08f..399c8199 100644
--- a/python/tvm_ffi/cython/string.pxi
+++ b/python/tvm_ffi/cython/string.pxi
@@ -19,12 +19,12 @@
# helper class for string/bytes handling
cdef inline str _string_obj_get_py_str(obj):
- cdef TVMFFIByteArray* bytes =
TVMFFIBytesGetByteArrayPtr((<Object>obj).chandle)
+ cdef TVMFFIByteArray* bytes =
TVMFFIBytesGetByteArrayPtr((<CObject>obj).chandle)
return bytearray_to_str(bytes)
cdef inline bytes _bytes_obj_get_py_bytes(obj):
- cdef TVMFFIByteArray* bytes =
TVMFFIBytesGetByteArrayPtr((<Object>obj).chandle)
+ cdef TVMFFIByteArray* bytes =
TVMFFIBytesGetByteArrayPtr((<CObject>obj).chandle)
return bytearray_to_bytes(bytes)
diff --git a/python/tvm_ffi/cython/tensor.pxi b/python/tvm_ffi/cython/tensor.pxi
index 8b78c809..dcf8e19c 100644
--- a/python/tvm_ffi/cython/tensor.pxi
+++ b/python/tvm_ffi/cython/tensor.pxi
@@ -230,8 +230,8 @@ def from_dlpack(
# helper class for shape handling
-def _shape_obj_get_py_tuple(obj: "Object") -> tuple[int, ...]:
- cdef TVMFFIShapeCell* shape = TVMFFIShapeGetCellPtr((<Object>obj).chandle)
+def _shape_obj_get_py_tuple(obj: "CObject") -> tuple[int, ...]:
+ cdef TVMFFIShapeCell* shape = TVMFFIShapeGetCellPtr((<CObject>obj).chandle)
return tuple(shape.data[i] for i in range(shape.size))
@@ -247,7 +247,7 @@ def _make_strides_from_shape(tuple shape: tuple[int, ...])
-> tuple[int, ...]:
return tuple(reversed(strides))
-cdef class Tensor(Object):
+cdef class Tensor(CObject):
"""Managed n-dimensional array compatible with DLPack.
It provides zero-copy interoperability with array libraries
@@ -268,6 +268,7 @@ cdef class Tensor(Object):
np.testing.assert_equal(np.from_dlpack(x), np.arange(6, dtype="int32"))
"""
+ __slots__ = ()
cdef DLTensor* cdltensor
@property
@@ -433,6 +434,7 @@ cdef DLPackExchangeAPI*
_dltensor_test_wrapper_get_exchange_api() noexcept:
cdef class DLTensorTestWrapper:
"""Wrapper of a Tensor that exposes DLPack protocol, only for testing
purpose.
"""
+ __slots__ = ()
__dlpack_c_exchange_api__ = pycapsule.PyCapsule_New(
_dltensor_test_wrapper_get_exchange_api(),
b"dlpack_exchange_api",
@@ -465,7 +467,7 @@ cdef inline object make_ret_dltensor(TVMFFIAny result):
cdef DLTensor* dltensor
dltensor = <DLTensor*>result.v_ptr
tensor = _CLASS_TENSOR.__new__(_CLASS_TENSOR)
- (<Object>tensor).chandle = NULL
+ (<CObject>tensor).chandle = NULL
(<Tensor>tensor).cdltensor = dltensor
return tensor
@@ -497,7 +499,7 @@ cdef inline object make_tensor_from_chandle(
pass
# default return the tensor
tensor = _CLASS_TENSOR.__new__(_CLASS_TENSOR)
- (<Object>tensor).chandle = chandle
+ (<CObject>tensor).chandle = chandle
(<Tensor>tensor).cdltensor = TVMFFITensorGetDLTensorPtr(chandle)
return tensor
diff --git a/python/tvm_ffi/cython/type_info.pxi
b/python/tvm_ffi/cython/type_info.pxi
index 050be909..2d665443 100644
--- a/python/tvm_ffi/cython/type_info.pxi
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -25,10 +25,10 @@ cdef class FieldGetter:
cdef TVMFFIFieldGetter getter
cdef int64_t offset
- def __call__(self, Object obj):
+ def __call__(self, CObject obj):
cdef TVMFFIAny result
cdef int c_api_ret_code
- cdef void* field_ptr = (<char*>(<Object>obj).chandle) + self.offset
+ cdef void* field_ptr = (<char*>(<CObject>obj).chandle) + self.offset
result.type_index = kTVMFFINone
result.v_int64 = 0
c_api_ret_code = self.getter(field_ptr, &result)
@@ -41,9 +41,9 @@ cdef class FieldSetter:
cdef TVMFFIFieldSetter setter
cdef int64_t offset
- def __call__(self, Object obj, value):
+ def __call__(self, CObject obj, value):
cdef int c_api_ret_code
- cdef void* field_ptr = (<char*>(<Object>obj).chandle) + self.offset
+ cdef void* field_ptr = (<char*>(<CObject>obj).chandle) + self.offset
TVMFFIPyCallFieldSetter(
TVMFFIPyArgSetterFactory_,
self.setter,
diff --git a/python/tvm_ffi/module.py b/python/tvm_ffi/module.py
index ad7dfce5..8bc6bb17 100644
--- a/python/tvm_ffi/module.py
+++ b/python/tvm_ffi/module.py
@@ -113,6 +113,7 @@ class Module(core.Object):
# tvm-ffi-stubgen(end)
entry_name: ClassVar[str] = "main" # constant for entry function name
+ __slots__ = ("__dict__",)
@property
def kind(self) -> str:
@@ -162,10 +163,10 @@ class Module(core.Object):
"""Accessor to allow getting functions as attributes."""
try:
func = self.get_function(name)
- self.__dict__[name] = func
- return func
- except AttributeError:
- raise AttributeError(f"Module has no function '{name}'")
+ except AttributeError as exc:
+ raise AttributeError(f"Module has no function '{name}'") from exc
+ setattr(self, name, func)
+ return func
def get_function(self, name: str, query_imports: bool = False) ->
core.Function:
"""Get function from the module.
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index 46074c28..561dd20b 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -18,7 +18,6 @@
from __future__ import annotations
-import functools
import json
import sys
from typing import Any, Callable, Literal, Sequence, TypeVar, overload
@@ -356,7 +355,13 @@ def _add_class_attrs(type_cls: type, type_info: TypeInfo)
-> type:
setattr(type_cls, "__init__", getattr(type_cls, "__ffi_init__"))
elif not issubclass(type_cls, core.PyNativeObject):
setattr(type_cls, "__init__", __init__invalid)
- is_container = type_info.type_key in ("ffi.Array", "ffi.Map", "ffi.List",
"ffi.Dict")
+ is_container = type_info.type_key in (
+ "ffi.Array",
+ "ffi.Map",
+ "ffi.List",
+ "ffi.Dict",
+ "ffi.Shape",
+ )
_setup_copy_methods(type_cls, has_shallow_copy, is_container=is_container)
return type_cls
@@ -395,12 +400,9 @@ def _copy_supported(self: Any) -> Any:
def _deepcopy_supported(self: Any, memo: Any = None) -> Any:
- return _get_deep_copy_func()(self)
-
+ from . import _ffi_api # noqa: PLC0415
[email protected]_cache(maxsize=1)
-def _get_deep_copy_func() -> core.Function:
- return get_global_func("ffi.DeepCopy")
+ return _ffi_api.DeepCopy(self)
def _replace_supported(self: Any, **kwargs: Any) -> Any:
diff --git a/tests/python/test_object.py b/tests/python/test_object.py
index cee90842..8f563433 100644
--- a/tests/python/test_object.py
+++ b/tests/python/test_object.py
@@ -136,6 +136,17 @@ def test_opaque_type_error() -> None:
)
+def test_object_init() -> None:
+ # Registered class with __c_ffi_init__ should work fine
+ pair = tvm_ffi.testing.TestIntPair(3, 4) # ty:
ignore[too-many-positional-arguments]
+ assert pair.a == 3 and pair.b == 4
+
+ # FFI-returned objects should work fine
+ obj = tvm_ffi.testing.create_object("testing.TestObjectBase", v_i64=7)
+ assert obj.__chandle__() != 0
+ assert obj.v_i64 == 7 # ty: ignore[unresolved-attribute]
+
+
def test_object_protocol() -> None:
class CompactObject:
def __init__(self, backend_obj: Any) -> None:
@@ -174,6 +185,38 @@ def test_unregistered_object_fallback() -> None:
_check_type(obj)
[email protected](
+ ("test_cls", "make_instance"),
+ [
+ (
+ tvm_ffi.testing.TestObjectBase,
+ lambda: tvm_ffi.testing.create_object("testing.TestObjectBase"),
+ ),
+ (
+ tvm_ffi.testing.TestIntPair,
+ lambda: tvm_ffi.testing.TestIntPair(1, 2), # ty:
ignore[too-many-positional-arguments]
+ ),
+ (
+ tvm_ffi.testing.TestObjectDerived,
+ lambda: tvm_ffi.testing.create_object(
+ "testing.TestObjectDerived",
+ v_i64=20,
+ v_map=tvm_ffi.convert({"a": 1}),
+ v_array=tvm_ffi.convert([1, 2]),
+ ),
+ ),
+ ],
+)
+def test_object_subclass_slots(test_cls: type, make_instance: Any) -> None:
+ slots = test_cls.__dict__.get("__slots__")
+ assert slots == ()
+ assert "__dict__" not in test_cls.__dict__
+ assert "__weakref__" not in test_cls.__dict__
+ obj = make_instance()
+ with pytest.raises(AttributeError):
+ obj._tvm_ffi_test_attr = "nope"
+
+
@pytest.mark.parametrize(
"test_cls, type_key, parent_cls",
[