This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new c33f1ecb feat(ffi)!: add __ffi_convert__ type attribute and
Python-defined type support (#503)
c33f1ecb is described below
commit c33f1ecbc1bf3ba5da1bf28d57fc30ba25a0f3d6
Author: Junru Shao <[email protected]>
AuthorDate: Wed Mar 18 22:00:40 2026 -0700
feat(ffi)!: add __ffi_convert__ type attribute and Python-defined type
support (#503)
## Summary
- Introduce `__ffi_convert__` type attribute: a per-type `Function` that
performs typed `AnyView -> TObjectRef` conversion via
`TypeTraits::TryCastFromAnyView`, registered through
`TypeAttrDef<T>()`
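- Add Python-defined type infrastructure in `dataclass.cc`:
`PyClassDeleter`, `MakeFFINew` (calloc-based factory), field
getter/setter dispatch, and `__ffi_shallow_copy__` for deep-copy support,
exposed through the global functions `ffi.GetFieldGetter`,
`ffi.MakeFieldSetter`, `ffi.MakeFFINew`, and `ffi.RegisterAutoInit`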
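- Register `__ffi_convert__` for the built-in static object types (Object,
String, Bytes, Error, Function, Shape, Tensor, Array, Map, List, Dict);
Module and OpaquePyObject are intentionally skipped, and the testing type
`TestIntPair` also registers it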
- Expose `ffi.FunctionFromExternC` global function for Python-side
construction of `Function` objects from raw C function pointers
## Test plan
- [x] Existing C++ tests pass (reflection, object creation,
serialization)
- [x] Existing Python tests pass (c_class / py_class decorator
integration)
- [x] CI lint (clang-format, clang-tidy) passes on changed files
- [x] Integration testing with Python-side py_class decorator
**BREAKING CHANGE:** Object base type registration now explicitly sets
`structural_eq_hash_kind` to `kTVMFFISEqHashKindUnsupported`, which may
affect types that previously inherited an uninitialized default.
---
include/tvm/ffi/reflection/init.h | 21 +++
include/tvm/ffi/reflection/registry.h | 2 +
src/ffi/extra/dataclass.cc | 254 ++++++++++++++++++++++++++++++++++
src/ffi/function.cc | 21 ++-
src/ffi/object.cc | 31 ++++-
src/ffi/testing/testing.cc | 2 +
6 files changed, 321 insertions(+), 10 deletions(-)
diff --git a/include/tvm/ffi/reflection/init.h
b/include/tvm/ffi/reflection/init.h
index 8853f054..009d1ac4 100644
--- a/include/tvm/ffi/reflection/init.h
+++ b/include/tvm/ffi/reflection/init.h
@@ -43,6 +43,27 @@ namespace tvm {
namespace ffi {
namespace reflection {
+namespace details {
+
+/*!
+ * \brief Typed AnyView -> TObjectRef conversion, registered as the
+ *        ``__ffi_convert__`` type attribute.
+ *
+ * Delegates to TypeTraits<TObjectRef>::TryCastFromAnyView on a raw copy of the
+ * view; when the cast is impossible, raises a TypeError naming both the source
+ * type key and the target type string.
+ *
+ * \tparam TObjectRef The object reference type to convert to.
+ * \param input The AnyView to convert.
+ * \return The converted object.
+ * \throws TypeError if `input` cannot be cast to TObjectRef.
+ */
+template <typename TObjectRef>
+TObjectRef FFIConvertFromAnyViewToObjectRef(AnyView input) {
+  TVMFFIAny raw = input.CopyToTVMFFIAny();
+  auto maybe_ref = TypeTraits<TObjectRef>::TryCastFromAnyView(&raw);
+  if (!maybe_ref) {
+    TVM_FFI_THROW(TypeError) << "Cannot cast from `" << TypeIndexToTypeKey(raw.type_index)
+                             << "` to `" << TypeTraits<TObjectRef>::TypeStr() << "`";
+  }
+  return *std::move(maybe_ref);
+}
+
+} // namespace details
+
/*!
* \brief Create a packed ``__ffi_init__`` constructor for the given type.
*
diff --git a/include/tvm/ffi/reflection/registry.h
b/include/tvm/ffi/reflection/registry.h
index 17c0078b..92379a1d 100644
--- a/include/tvm/ffi/reflection/registry.h
+++ b/include/tvm/ffi/reflection/registry.h
@@ -693,6 +693,8 @@ inline constexpr const char* kHash = "__ffi_hash__";
inline constexpr const char* kEq = "__ffi_eq__";
/*! \brief Type attribute for custom recursive three-way comparison. */
inline constexpr const char* kCompare = "__ffi_compare__";
+/*!
+ * \brief Type attribute for converting AnyView to a specific reflected object type.
+ *
+ * Built-in types register reflection::details::FFIConvertFromAnyViewToObjectRef<T>
+ * under this key via TypeAttrDef<T>().def(...).
+ */
+inline constexpr const char* kConvert = "__ffi_convert__";
} // namespace type_attr
/*!
diff --git a/src/ffi/extra/dataclass.cc b/src/ffi/extra/dataclass.cc
index 9ea64cf2..8b5eebba 100644
--- a/src/ffi/extra/dataclass.cc
+++ b/src/ffi/extra/dataclass.cc
@@ -1662,6 +1662,254 @@ class RecursiveComparer : public
ObjectGraphDFS<RecursiveComparer, CompareFrame,
}
};
+// ---------- Python-defined type support ----------
+
+/*!
+ * \brief Deleter for objects whose layout is defined from Python via Field descriptors.
+ *
+ * For the "strong" phase, iterates all reflected fields and destructs
+ * Any/ObjectRef values in-place to release references. Fields whose static
+ * type index is kTVMFFINone are stored as `Any` by the field setter
+ * (WriteFieldValue), so they are destructed as `Any` here as well. For the
+ * "weak" phase, frees the underlying calloc'd memory.
+ *
+ * \param self_void Pointer to the object (its TVMFFIObject header).
+ * \param flags Bitmask of kTVMFFIObjectDeleterFlagBitMaskStrong / ...Weak.
+ */
+void PyClassDeleter(void* self_void, int flags) {
+  TVMFFIObject* self = static_cast<TVMFFIObject*>(self_void);
+  if (flags & kTVMFFIObjectDeleterFlagBitMaskStrong) {
+    const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(self->type_index);
+    reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* finfo) {
+      void* field_addr = reinterpret_cast<char*>(self) + finfo->offset;
+      int32_t ti = finfo->field_static_type_index;
+      if (ti == TypeIndex::kTVMFFIAny || ti == TypeIndex::kTVMFFINone) {
+        // Any-backed field (the setter stores kTVMFFINone fields as Any too):
+        // run the destructor so any owned reference is released.
+        static_cast<Any*>(field_addr)->~Any();
+      } else if (ti >= TypeIndex::kTVMFFIStaticObjectBegin) {
+        // ObjectRef field: call destructor to DecRef
+        static_cast<ObjectRef*>(field_addr)->~ObjectRef();
+      }
+      // POD types (int, float, bool, etc.): no cleanup needed
+    });
+  }
+  if (flags & kTVMFFIObjectDeleterFlagBitMaskWeak) {
+    std::free(self_void);
+  }
+}
+
+/*!
+ * \brief Field getter used by Python-defined types for a field stored as T.
+ *
+ * Boxes a copy of the value at `field` into an owning Any and transfers that
+ * ownership into the raw `result` slot. Exceptions are converted to an error
+ * return code by the TVM_FFI_SAFE_CALL_BEGIN/END wrapper.
+ *
+ * \tparam T The C++ type stored at the field address.
+ * \param field Address of the field inside the object.
+ * \param result Output slot receiving the boxed value.
+ * \return The safe-call status code.
+ */
+template <typename T>
+int PyClassFieldGetter(void* field, TVMFFIAny* result) {
+  TVM_FFI_SAFE_CALL_BEGIN();
+  const T& stored = *static_cast<const T*>(field);
+  Any boxed(stored);
+  *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(boxed));
+  TVM_FFI_SAFE_CALL_END();
+}
+
+/*!
+ * \brief Select the TVMFFIFieldGetter matching a field's static type index.
+ *
+ * POD type indices map to a getter reading exactly that C++ type;
+ * kTVMFFIAny and kTVMFFINone map to the `Any` getter; indices at or above
+ * kTVMFFIStaticObjectBegin map to the `ObjectRef` getter. Any other index
+ * raises a ValueError.
+ *
+ * \param field_type_index The static type index of the field.
+ * \return The getter function pointer, reinterpreted as int64_t for FFI transport.
+ */
+int64_t GetFieldGetter(int32_t field_type_index) {
+  TVMFFIFieldGetter fn = nullptr;
+  if (field_type_index == TypeIndex::kTVMFFIInt) {
+    fn = &PyClassFieldGetter<int64_t>;
+  } else if (field_type_index == TypeIndex::kTVMFFIFloat) {
+    fn = &PyClassFieldGetter<double>;
+  } else if (field_type_index == TypeIndex::kTVMFFIBool) {
+    fn = &PyClassFieldGetter<bool>;
+  } else if (field_type_index == TypeIndex::kTVMFFIOpaquePtr) {
+    fn = &PyClassFieldGetter<void*>;
+  } else if (field_type_index == TypeIndex::kTVMFFIDataType) {
+    fn = &PyClassFieldGetter<DLDataType>;
+  } else if (field_type_index == TypeIndex::kTVMFFIDevice) {
+    fn = &PyClassFieldGetter<DLDevice>;
+  } else if (field_type_index == TypeIndex::kTVMFFIAny ||
+             field_type_index == TypeIndex::kTVMFFINone) {
+    fn = &PyClassFieldGetter<Any>;
+  } else if (field_type_index >= TypeIndex::kTVMFFIStaticObjectBegin) {
+    fn = &PyClassFieldGetter<ObjectRef>;
+  } else {
+    TVM_FFI_THROW(ValueError) << "Unsupported field type index for getter: "
+                              << field_type_index;
+  }
+  return reinterpret_cast<int64_t>(fn);
+}
+
+/*!
+ * \brief Store an already-converted value into a field's raw storage.
+ *
+ * The field's static type index decides which C++ type the storage at
+ * `field_addr` is treated as. The value is cast to that type and assigned.
+ * Any-backed storage (kTVMFFIAny / kTVMFFINone) takes the value by move.
+ *
+ * \param field_addr Address of the field inside the object.
+ * \param field_type_index The static type index of the field.
+ * \param value The converted value to store.
+ * \throws ValueError if the type index is not a supported field type.
+ */
+void WriteFieldValue(void* field_addr, int32_t field_type_index, Any value) {
+  if (field_type_index == TypeIndex::kTVMFFIAny || field_type_index == TypeIndex::kTVMFFINone) {
+    *static_cast<Any*>(field_addr) = std::move(value);
+    return;
+  }
+  switch (field_type_index) {
+    case TypeIndex::kTVMFFIInt:
+      *static_cast<int64_t*>(field_addr) = value.cast<int64_t>();
+      break;
+    case TypeIndex::kTVMFFIFloat:
+      *static_cast<double*>(field_addr) = value.cast<double>();
+      break;
+    case TypeIndex::kTVMFFIBool:
+      *static_cast<bool*>(field_addr) = value.cast<bool>();
+      break;
+    case TypeIndex::kTVMFFIOpaquePtr:
+      *static_cast<void**>(field_addr) = value.cast<void*>();
+      break;
+    case TypeIndex::kTVMFFIDataType:
+      *static_cast<DLDataType*>(field_addr) = value.cast<DLDataType>();
+      break;
+    case TypeIndex::kTVMFFIDevice:
+      *static_cast<DLDevice*>(field_addr) = value.cast<DLDevice>();
+      break;
+    default:
+      // Remaining valid indices are object types; everything else is rejected.
+      if (field_type_index < TypeIndex::kTVMFFIStaticObjectBegin) {
+        TVM_FFI_THROW(ValueError) << "Unsupported field type index for setter: "
+                                  << field_type_index;
+      }
+      *static_cast<ObjectRef*>(field_addr) = value.cast<ObjectRef>();
+      break;
+  }
+}
+
+/*!
+ * \brief Create a FunctionObj setter for a Python-defined field.
+ *
+ * The returned Function accepts exactly (OpaquePtr field_addr, AnyView value).
+ * It calls f_convert to coerce the value via the type_converter, then writes
+ * the result to the field with WriteFieldValue.
+ *
+ * \param field_type_index The static type index of the field.
+ * \param type_converter_int Opaque pointer (as int64_t) to the Python
+ *        _TypeConverter. It is borrowed (not ref-counted here), so presumably
+ *        the caller keeps it alive as long as the setter may run.
+ * \param f_convert_int C function pointer (as int64_t): int(void*, const TVMFFIAny*, TVMFFIAny*).
+ *        Returns 0 on success, -1 on error (error stored in TLS).
+ * \return A packed Function suitable for use as a FunctionObj setter.
+ * \throws ValueError if f_convert_int is null.
+ */
+Function MakeFieldSetter(int32_t field_type_index, int64_t type_converter_int,
+                         int64_t f_convert_int) {
+  // NOLINTNEXTLINE(performance-no-int-to-ptr)
+  void* type_converter = reinterpret_cast<void*>(type_converter_int);
+  using FConvert = int (*)(void*, const TVMFFIAny*, TVMFFIAny*);
+  // NOLINTNEXTLINE(performance-no-int-to-ptr)
+  auto f_convert = reinterpret_cast<FConvert>(f_convert_int);
+  // Fail when the setter is built rather than calling through a null pointer
+  // on the first field write.
+  if (f_convert == nullptr) {
+    TVM_FFI_THROW(ValueError) << "MakeFieldSetter: f_convert must not be null";
+  }
+
+  return Function::FromPacked([field_type_index, type_converter, f_convert](
+                                  const AnyView* args, int32_t num_args, Any* /*rv*/) {
+    // args[1] is read below; never index past the end of `args`.
+    if (num_args != 2) {
+      TVM_FFI_THROW(TypeError) << "Field setter expects 2 arguments (field_addr, value), but got "
+                               << num_args;
+    }
+    void* field_addr = args[0].cast<void*>();
+    // Call the Cython-level type converter via C function pointer. `converted`
+    // starts fully zeroed (None) so it is a valid TVMFFIAny even if the
+    // converter fails before writing to it.
+    TVMFFIAny converted{};
+    converted.type_index = TypeIndex::kTVMFFINone;
+    converted.v_int64 = 0;
+    // Relies on AnyView being layout-compatible with TVMFFIAny.
+    int err = f_convert(type_converter, reinterpret_cast<const TVMFFIAny*>(&args[1]), &converted);
+    if (err != 0) {
+      throw details::MoveFromSafeCallRaised();
+    }
+    // Take ownership of the converted value and write to the field.
+    Any owned = details::AnyUnsafe::MoveTVMFFIAnyToAny(&converted);
+    WriteFieldValue(field_addr, field_type_index, std::move(owned));
+  });
+}
+
+/*!
+ * \brief Initialize the TVMFFIObject header of freshly allocated storage for a
+ *        Python-defined type.
+ *
+ * Sets the type index, sets the combined reference count to
+ * kCombinedRefCountBothOne, and installs PyClassDeleter. Field bytes are not touched.
+ *
+ * \param obj_ptr Storage of at least sizeof(TVMFFIObject) bytes.
+ * \param type_index The type index of the Python-defined type.
+ */
+void InitPyClassObjectHeader(void* obj_ptr, int32_t type_index) {
+  TVMFFIObject* ffi_obj = static_cast<TVMFFIObject*>(obj_ptr);
+  ffi_obj->type_index = type_index;
+  ffi_obj->combined_ref_count = details::kCombinedRefCountBothOne;
+  ffi_obj->deleter = PyClassDeleter;
+}
+
+/*!
+ * \brief Register `fn` as type attribute `name` on `type_index`, ensuring the
+ *        attribute column exists first.
+ */
+void RegisterPyClassTypeAttr(int32_t type_index, const char* name, const Function& fn) {
+  reflection::EnsureTypeAttrColumn(name);
+  TVMFFIByteArray attr_name = {name, std::char_traits<char>::length(name)};
+  TVMFFIAny attr_value = AnyView(fn).CopyToTVMFFIAny();
+  TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index, &attr_name, &attr_value));
+}
+
+/*!
+ * \brief Register ``__ffi_new__`` and ``__ffi_shallow_copy__`` for a Python-defined type.
+ *
+ * ``__ffi_new__`` is a factory Function. It allocates zero-initialized memory
+ * of `total_size` bytes, sets up the TVMFFIObject header, and returns an
+ * ObjectRef. It is registered as a type attribute so that
+ * ``CreateEmptyObject`` can find it. ``__ffi_shallow_copy__`` allocates a new
+ * object with the same layout and copies every reflected field by value,
+ * IncRef-ing Any/ObjectRef fields. It is registered both as a type attribute
+ * (for generic deep copy) and as an instance method (so Python can call it).
+ *
+ * \param type_index The type index of the Python-defined type.
+ * \param total_size The total object size in bytes (header + fields).
+ * \throws ValueError if total_size is too small to hold the object header.
+ */
+void MakeFFINew(int32_t type_index, int32_t total_size) {
+  // The header writes below assume a full TVMFFIObject fits in the allocation.
+  if (total_size < static_cast<int32_t>(sizeof(TVMFFIObject))) {
+    TVM_FFI_THROW(ValueError) << "MakeFFINew: total_size " << total_size
+                              << " is smaller than the object header (" << sizeof(TVMFFIObject)
+                              << " bytes) for type index " << type_index;
+  }
+  // Pre-compute type_info pointer (stable for the process lifetime). Only the
+  // shallow-copy lambda needs it; new_fn relies on calloc zero-initialization.
+  const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index);
+
+  Function new_fn = Function::FromTyped([type_index, total_size]() -> ObjectRef {
+    void* obj_ptr = std::calloc(1, static_cast<size_t>(total_size));
+    if (!obj_ptr) {
+      TVM_FFI_THROW(RuntimeError) << "Failed to allocate " << total_size << " bytes for type "
+                                  << TypeIndexToTypeKey(type_index);
+    }
+    InitPyClassObjectHeader(obj_ptr, type_index);
+    // calloc zero-initializes all bytes. For non-trivial field types:
+    //   - Any: zero state is {type_index=kTVMFFINone, v_int64=0}, representing None.
+    //   - ObjectRef: zero state is a null pointer.
+    // Both are valid initial states whose destructors and assignment operators
+    // handle correctly, so no placement construction is needed.
+    Object* obj = reinterpret_cast<Object*>(obj_ptr);
+    return ObjectRef(details::ObjectUnsafe::ObjectPtrFromOwned<Object>(obj));
+  });
+  RegisterPyClassTypeAttr(type_index, "__ffi_new__", new_fn);
+
+  Function copy_fn = Function::FromTyped(
+      [type_index, total_size, type_info](const Object* src) -> ObjectRef {
+        if (src == nullptr) {
+          TVM_FFI_THROW(ValueError) << "__ffi_shallow_copy__: source object is null";
+        }
+        void* obj_ptr = std::calloc(1, static_cast<size_t>(total_size));
+        if (!obj_ptr) {
+          TVM_FFI_THROW(RuntimeError) << "Failed to allocate for shallow copy";
+        }
+        InitPyClassObjectHeader(obj_ptr, type_index);
+        reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* finfo) {
+          void* dst = static_cast<char*>(obj_ptr) + finfo->offset;
+          const void* field_src = reinterpret_cast<const char*>(src) + finfo->offset;
+          int32_t ti = finfo->field_static_type_index;
+          if (ti == TypeIndex::kTVMFFIAny || ti == TypeIndex::kTVMFFINone) {
+            // Any-backed storage (the setter stores kTVMFFINone fields as Any
+            // too): copy-construct so owned references are IncRef'd rather
+            // than aliased bitwise.
+            new (dst) Any(*static_cast<const Any*>(field_src));
+          } else if (ti >= TypeIndex::kTVMFFIStaticObjectBegin) {
+            new (dst) ObjectRef(*static_cast<const ObjectRef*>(field_src));
+          } else {
+            // POD: bitwise copy
+            std::memcpy(dst, field_src, static_cast<size_t>(finfo->size));
+          }
+        });
+        Object* obj = reinterpret_cast<Object*>(obj_ptr);
+        return ObjectRef(details::ObjectUnsafe::ObjectPtrFromOwned<Object>(obj));
+      });
+  // Register as type attribute for generic deep copy lookup.
+  RegisterPyClassTypeAttr(type_index, reflection::type_attr::kShallowCopy, copy_fn);
+  // Also register as an instance method so Python can call __ffi_shallow_copy__.
+  TVMFFIMethodInfo copy_method{};
+  copy_method.name =
+      TVMFFIByteArray{reflection::type_attr::kShallowCopy,
+                      std::char_traits<char>::length(reflection::type_attr::kShallowCopy)};
+  copy_method.doc = TVMFFIByteArray{nullptr, 0};
+  copy_method.flags = 0;
+  copy_method.method = AnyView(copy_fn).CopyToTVMFFIAny();
+  copy_method.metadata = TVMFFIByteArray{nullptr, 0};
+  TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index, &copy_method));
+}
+
} // namespace
// ============================================================================
@@ -1725,6 +1973,12 @@ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
// MakeInit
refl::GlobalDef().def("ffi.MakeInit", refl::MakeInit);
+ // Python-defined type support
+ refl::EnsureTypeAttrColumn("__ffi_new__");
+ refl::GlobalDef().def("ffi.GetFieldGetter", GetFieldGetter);
+ refl::GlobalDef().def("ffi.MakeFieldSetter", MakeFieldSetter);
+ refl::GlobalDef().def("ffi.MakeFFINew", MakeFFINew);
+ refl::GlobalDef().def("ffi.RegisterAutoInit", refl::RegisterAutoInit);
// Deep copy
refl::EnsureTypeAttrColumn(refl::type_attr::kShallowCopy);
refl::GlobalDef().def("ffi.DeepCopy", DeepCopy);
diff --git a/src/ffi/function.cc b/src/ffi/function.cc
index 4f83d28a..4b378f74 100644
--- a/src/ffi/function.cc
+++ b/src/ffi/function.cc
@@ -229,11 +229,18 @@ TVM_FFI_STATIC_INIT_BLOCK() {
})
.def("ffi.String", [](tvm::ffi::String val) -> tvm::ffi::String { return
val; })
.def("ffi.Bytes", [](tvm::ffi::Bytes val) -> tvm::ffi::Bytes { return
val; })
- .def("ffi.GetGlobalFuncMetadata", [](const tvm::ffi::String& name) ->
tvm::ffi::String {
- const auto* f = tvm::ffi::GlobalFunctionTable::Global()->Get(name);
- if (f == nullptr) {
- TVM_FFI_THROW(RuntimeError) << "Global Function is not found: " <<
name;
- }
- return f->metadata_data;
- });
+ .def("ffi.GetGlobalFuncMetadata",
+ [](const tvm::ffi::String& name) -> tvm::ffi::String {
+ const auto* f =
tvm::ffi::GlobalFunctionTable::Global()->Get(name);
+ if (f == nullptr) {
+ TVM_FFI_THROW(RuntimeError) << "Global Function is not found: "
<< name;
+ }
+ return f->metadata_data;
+ })
+ .def("ffi.FunctionFromExternC",
+ [](void* self, void* safe_call, void* deleter) ->
tvm::ffi::Function {
+ return tvm::ffi::Function::FromExternC(self,
+
reinterpret_cast<TVMFFISafeCallType>(safe_call),
+ reinterpret_cast<void
(*)(void*)>(deleter));
+ });
}
diff --git a/src/ffi/object.cc b/src/ffi/object.cc
index 011af14e..2429b869 100644
--- a/src/ffi/object.cc
+++ b/src/ffi/object.cc
@@ -23,7 +23,10 @@
#include <tvm/ffi/any.h>
#include <tvm/ffi/c_api.h>
#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/dict.h>
+#include <tvm/ffi/container/list.h>
#include <tvm/ffi/container/map.h>
+#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/memory.h>
@@ -33,10 +36,7 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/string.h>
-#include <algorithm>
#include <memory>
-#include <string_view>
-#include <unordered_map>
#include <utility>
#include <vector>
@@ -368,6 +368,7 @@ class TypeTable {
-1);
TVMFFITypeMetadata info;
info.total_size = sizeof(Object);
+ info.structural_eq_hash_kind = kTVMFFISEqHashKindUnsupported;
info.creator = nullptr;
info.doc = TVMFFIByteArray{nullptr, 0};
RegisterTypeMetadata(Object::_type_index, &info);
@@ -596,6 +597,30 @@ namespace {
TVM_FFI_STATIC_INIT_BLOCK() {
using namespace tvm::ffi;
namespace refl = tvm::ffi::reflection;
+ refl::TypeAttrDef<Object>().def(refl::type_attr::kConvert,
+
&refl::details::FFIConvertFromAnyViewToObjectRef<ObjectRef>);
+ refl::TypeAttrDef<details::StringObj>().def(
+ refl::type_attr::kConvert,
&refl::details::FFIConvertFromAnyViewToObjectRef<String>);
+ refl::TypeAttrDef<details::BytesObj>().def(
+ refl::type_attr::kConvert,
&refl::details::FFIConvertFromAnyViewToObjectRef<Bytes>);
+ refl::TypeAttrDef<ErrorObj>().def(refl::type_attr::kConvert,
+
&refl::details::FFIConvertFromAnyViewToObjectRef<Error>);
+ refl::TypeAttrDef<FunctionObj>().def(refl::type_attr::kConvert,
+
&refl::details::FFIConvertFromAnyViewToObjectRef<Function>);
+ refl::TypeAttrDef<ShapeObj>().def(refl::type_attr::kConvert,
+
&refl::details::FFIConvertFromAnyViewToObjectRef<Shape>);
+ refl::TypeAttrDef<TensorObj>().def(refl::type_attr::kConvert,
+
&refl::details::FFIConvertFromAnyViewToObjectRef<Tensor>);
+ refl::TypeAttrDef<ArrayObj>().def(refl::type_attr::kConvert,
+
&refl::details::FFIConvertFromAnyViewToObjectRef<Array<Any>>);
+ refl::TypeAttrDef<MapObj>().def(refl::type_attr::kConvert,
+
&refl::details::FFIConvertFromAnyViewToObjectRef<Map<Any, Any>>);
+ // Skipped: TypeIndex::kTVMFFIModule
+ // Skipped: TypeIndex::kTVMFFIOpaquePyObject
+ refl::TypeAttrDef<ListObj>().def(refl::type_attr::kConvert,
+
&refl::details::FFIConvertFromAnyViewToObjectRef<List<Any>>);
+ refl::TypeAttrDef<DictObj>().def(
+ refl::type_attr::kConvert,
&refl::details::FFIConvertFromAnyViewToObjectRef<Dict<Any, Any>>);
refl::GlobalDef()
.def_method(
"ffi.GetRegisteredTypeKeys",
diff --git a/src/ffi/testing/testing.cc b/src/ffi/testing/testing.cc
index 111e6447..5d4373f3 100644
--- a/src/ffi/testing/testing.cc
+++ b/src/ffi/testing/testing.cc
@@ -76,6 +76,8 @@ TVM_FFI_STATIC_INIT_BLOCK() {
.def_ro("a", &TestIntPairObj::a, "Field `a`")
.def_ro("b", &TestIntPairObj::b, "Field `b`")
.def("sum", &TestIntPair::Sum, "Method to compute sum of a and b");
+ refl::TypeAttrDef<TestIntPairObj>().def(
+ refl::type_attr::kConvert,
&refl::details::FFIConvertFromAnyViewToObjectRef<TestIntPair>);
}
class TestObjectBase : public Object {