This is an automated email from the ASF dual-hosted git repository.
bohan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9cb6705f99 [FFI] Enhance FFI Object exception safety during init (#18050)
9cb6705f99 is described below
commit 9cb6705f99859ee1ba2c81f3f7dd89f9c3cbd6a2
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun Jun 8 10:27:39 2025 -0400
[FFI] Enhance FFI Object exception safety during init (#18050)
---
python/tvm/ffi/container.py | 6 ++++++
python/tvm/ffi/cython/object.pxi | 15 ++++++++++++++-
python/tvm/runtime/ndarray.py | 3 +++
tests/python/ffi/test_container.py | 14 ++++++++++++++
4 files changed, 37 insertions(+), 1 deletion(-)
diff --git a/python/tvm/ffi/container.py b/python/tvm/ffi/container.py
index 6ababe2557..66038976f5 100644
--- a/python/tvm/ffi/container.py
+++ b/python/tvm/ffi/container.py
@@ -78,6 +78,9 @@ class Array(core.Object, collections.abc.Sequence):
return _ffi_api.ArraySize(self)
def __repr__(self):
+ # exception safety handling for chandle=None
+ if self.__chandle__() == 0:
+ return type(self).__name__ + "(chandle=None)"
return "[" + ", ".join([x.__repr__() for x in self]) + "]"
@@ -197,4 +200,7 @@ class Map(core.Object, collections.abc.Mapping):
return self[key] if key in self else default
def __repr__(self):
+ # exception safety handling for chandle=None
+ if self.__chandle__() == 0:
+ return type(self).__name__ + "(chandle=None)"
return "{" + ", ".join([f"{k.__repr__()}: {v.__repr__()}" for k, v in
self.items()]) + "}"
diff --git a/python/tvm/ffi/cython/object.pxi b/python/tvm/ffi/cython/object.pxi
index f971ca8f5a..4efedf35d8 100644
--- a/python/tvm/ffi/cython/object.pxi
+++ b/python/tvm/ffi/cython/object.pxi
@@ -85,9 +85,15 @@ cdef class Object:
"""
cdef void* chandle
+ def __cinit__(self):
+ # initialize chandle to NULL to avoid leak in
+ # case of error before chandle is set
+ self.chandle = NULL
+
def __dealloc__(self):
if self.chandle != NULL:
CHECK_CALL(TVMFFIObjectFree(self.chandle))
+ self.chandle = NULL
def __ctypes_handle__(self):
return ctypes_handle(self.chandle)
@@ -116,16 +122,23 @@ cdef class Object:
self.chandle = NULL
def __getattr__(self, name):
+ if self.chandle == NULL:
+ raise AttributeError(f"{type(self)} has no attribute {name}")
try:
return __object_getattr__(self, name)
except AttributeError:
raise AttributeError(f"{type(self)} has no attribute {name}")
def __dir__(self):
+ # exception safety handling for chandle=None
+ if self.chandle == NULL:
+ return []
return __object_dir__(self)
def __repr__(self):
- # make sure repr is a raw string
+ # exception safety handling for chandle=None
+ if self.chandle == NULL:
+ return type(self).__name__ + "(chandle=None)"
return str(__object_repr__(self))
def __eq__(self, other):
diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py
index 9d49d9c51d..538fa15c8a 100644
--- a/python/tvm/runtime/ndarray.py
+++ b/python/tvm/runtime/ndarray.py
@@ -164,6 +164,9 @@ class NDArray(tvm.ffi.core.NDArray):
return self
def __repr__(self):
+ # exception safety handling for chandle=None
+ if self.__chandle__() == 0:
+ return type(self).__name__ + "(chandle=None)"
res = f"<tvm.nd.NDArray shape={self.shape}, {self.device}>\n"
res += self.numpy().__repr__()
return res
diff --git a/tests/python/ffi/test_container.py b/tests/python/ffi/test_container.py
index b20c221b4f..5ac3af1799 100644
--- a/tests/python/ffi/test_container.py
+++ b/tests/python/ffi/test_container.py
@@ -27,6 +27,20 @@ def test_array():
assert (a_slice[0], a_slice[1]) == (1, 2)
+def test_bad_constructor_init_state():
+ """Test when an error is raised before __init_handle_by_constructor__ is called
+
+ In this case, the FFI binding must gracefully handle both repr
+ and dealloc by ensuring that chandle is initialized to NULL and
+ that repr has a code path for a NULL handle.
+ """
+ with pytest.raises(TypeError):
+ tvm_ffi.Array(1)
+
+ with pytest.raises(AttributeError):
+ tvm_ffi.Map(1)
+
+
def test_array_of_array_map():
a = tvm_ffi.convert([[1, 2, 3], {"A": 5, "B": 6}])
assert isinstance(a, tvm_ffi.Array)