This is an automated email from the ASF dual-hosted git repository.
yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 4fe8b2b [ABI] Further clarify Function ABI in cpp (#71)
4fe8b2b is described below
commit 4fe8b2b79dfeb469b2499acecb3e10038ddcee0f
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Sep 29 14:32:11 2025 -0400
[ABI] Further clarify Function ABI in cpp (#71)
---
include/tvm/ffi/c_api.h | 13 ++++
include/tvm/ffi/function.h | 165 ++++++++++++++++++---------------------------
pyproject.toml | 2 +-
python/tvm_ffi/__init__.py | 2 +-
tests/cpp/test_function.cc | 10 +++
5 files changed, 90 insertions(+), 102 deletions(-)
diff --git a/include/tvm/ffi/c_api.h b/include/tvm/ffi/c_api.h
index 0cac1f7..c688aee 100644
--- a/include/tvm/ffi/c_api.h
+++ b/include/tvm/ffi/c_api.h
@@ -434,6 +434,19 @@ typedef int (*TVMFFISafeCallType)(void* handle, const
TVMFFIAny* args, int32_t n
typedef struct {
/*! \brief A C API compatible call with exception catching. */
TVMFFISafeCallType safe_call;
+ /*!
+ * \brief A function pointer to an underlying cpp call.
+ *
+ * The signature is the same as TVMFFISafeCallType except the return type is
void,
+ * and the function throws exception directly instead of returning error
code.
+ * We use void* here to avoid depending on c++ compiler.
+ *
+ * This pointer should be set to NULL for functions that are not originally
created in cpp.
+ *
+ * \note The caller must assume the same cpp exception catching abi when
using this pointer.
+ * When used across FFI boundaries, always use safe_call.
+ */
+ void* cpp_call;
} TVMFFIFunctionCell;
/*!
diff --git a/include/tvm/ffi/function.h b/include/tvm/ffi/function.h
index 920856d..ba3ec0d 100644
--- a/include/tvm/ffi/function.h
+++ b/include/tvm/ffi/function.h
@@ -102,9 +102,8 @@ class FunctionObj : public Object, public
TVMFFIFunctionCell {
public:
/*! \brief Typedef for C++ style calling signature that comes with exception
propagation */
typedef void (*FCall)(const FunctionObj*, const AnyView*, int32_t, Any*);
+ using TVMFFIFunctionCell::cpp_call;
using TVMFFIFunctionCell::safe_call;
- /*! \brief A C++ style call implementation, with exception propagation in
C++ style. */
- FCall call;
/*!
* \brief Call the function in packed format.
* \param args The arguments
@@ -112,7 +111,11 @@ class FunctionObj : public Object, public
TVMFFIFunctionCell {
* \param result The return value.
*/
TVM_FFI_INLINE void CallPacked(const AnyView* args, int32_t num_args, Any*
result) const {
- this->call(this, args, num_args, result);
+ // if cpp_call is set, use it to call the function, otherwise, redirect to
safe_call
+ // use conditional expression here so the select is branchless
+ FCall call_ptr =
+ this->cpp_call ? reinterpret_cast<FCall>(this->cpp_call) :
CppCallDedirectToSafeCall;
+ (*call_ptr)(this, args, num_args, result);
}
/// \cond Doxygen_Suppress
static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIFunction;
@@ -122,18 +125,15 @@ class FunctionObj : public Object, public
TVMFFIFunctionCell {
protected:
/*! \brief Make default constructor protected. */
FunctionObj() {}
- /// \cond Doxygen_Suppress
- // Implementing safe call style
- static int SafeCall(void* func, const TVMFFIAny* args, int32_t num_args,
TVMFFIAny* result) {
- TVM_FFI_SAFE_CALL_BEGIN();
- TVM_FFI_ICHECK_LT(result->type_index, TypeIndex::kTVMFFIStaticObjectBegin);
- FunctionObj* self = static_cast<FunctionObj*>(func);
- self->call(self, reinterpret_cast<const AnyView*>(args), num_args,
- reinterpret_cast<Any*>(result));
- TVM_FFI_SAFE_CALL_END();
- }
- /// \endcond
friend class Function;
+
+ private:
+ static void CppCallDedirectToSafeCall(const FunctionObj* func, const
AnyView* args,
+ int32_t num_args, Any* rv) {
+ FunctionObj* self =
static_cast<FunctionObj*>(const_cast<FunctionObj*>(func));
+ TVM_FFI_CHECK_SAFE_CALL(self->safe_call(self, reinterpret_cast<const
TVMFFIAny*>(args),
+ num_args,
reinterpret_cast<TVMFFIAny*>(rv)));
+ }
};
namespace details {
@@ -154,87 +154,66 @@ class FunctionObjImpl : public FunctionObj {
*/
explicit FunctionObjImpl(TCallable callable) : callable_(callable) {
this->safe_call = SafeCall;
- this->call = Call;
+ this->cpp_call = reinterpret_cast<void*>(CppCall);
}
private:
// implementation of call
- static void Call(const FunctionObj* func, const AnyView* args, int32_t
num_args, Any* result) {
+ static void CppCall(const FunctionObj* func, const AnyView* args, int32_t
num_args, Any* result) {
(static_cast<const TSelf*>(func))->callable_(args, num_args, result);
}
-
+ /// \cond Doxygen_Suppress
+ // Implementing safe call style
+ static int SafeCall(void* func, const TVMFFIAny* args, int32_t num_args,
TVMFFIAny* result) {
+ TVM_FFI_SAFE_CALL_BEGIN();
+ TVM_FFI_ICHECK_LT(result->type_index, TypeIndex::kTVMFFIStaticObjectBegin);
+ FunctionObj* self = static_cast<FunctionObj*>(func);
+ reinterpret_cast<FCall>(self->cpp_call)(self, reinterpret_cast<const
AnyView*>(args), num_args,
+ reinterpret_cast<Any*>(result));
+ TVM_FFI_SAFE_CALL_END();
+ }
+ /// \endcond
/*! \brief Type-erased field for storing callable object*/
mutable TStorage callable_;
};
/*!
- * \brief Base class to provide a common implementation to redirect call to
safecall
- * \tparam Derived The derived class in CRTP-idiom
+ * \brief FunctionObj specialization for raw C style callback where handle and
deleter are null.
*/
-template <typename Derived>
-struct RedirectCallToSafeCall {
- static void Call(const FunctionObj* func, const AnyView* args, int32_t
num_args, Any* rv) {
- Derived* self = static_cast<Derived*>(const_cast<FunctionObj*>(func));
- TVM_FFI_CHECK_SAFE_CALL(self->RedirectSafeCall(reinterpret_cast<const
TVMFFIAny*>(args),
- num_args,
reinterpret_cast<TVMFFIAny*>(rv)));
- }
-
- static int32_t SafeCall(void* func, const TVMFFIAny* args, int32_t num_args,
TVMFFIAny* rv) {
- Derived* self = reinterpret_cast<Derived*>(func);
- return self->RedirectSafeCall(args, num_args, rv);
+class ExternCFunctionObjNullHandleImpl : public FunctionObj {
+ public:
+ explicit ExternCFunctionObjNullHandleImpl(TVMFFISafeCallType safe_call) {
+ this->safe_call = safe_call;
+ this->cpp_call = nullptr;
}
};
/*!
* \brief FunctionObj specialization that leverages C-style callback
definitions.
*/
-class ExternCFunctionObjImpl : public FunctionObj,
- public
RedirectCallToSafeCall<ExternCFunctionObjImpl> {
+class ExternCFunctionObjImpl : public FunctionObj {
public:
- using RedirectCallToSafeCall<ExternCFunctionObjImpl>::SafeCall;
-
ExternCFunctionObjImpl(void* self, TVMFFISafeCallType safe_call, void
(*deleter)(void* self))
: self_(self), safe_call_(safe_call), deleter_(deleter) {
- this->call = RedirectCallToSafeCall<ExternCFunctionObjImpl>::Call;
- this->safe_call = RedirectCallToSafeCall<ExternCFunctionObjImpl>::SafeCall;
+ this->safe_call = SafeCall;
+ this->cpp_call = nullptr;
}
- ~ExternCFunctionObjImpl() { deleter_(self_); }
-
- TVM_FFI_INLINE int32_t RedirectSafeCall(const TVMFFIAny* args, int32_t
num_args,
- TVMFFIAny* rv) const {
- return safe_call_(self_, args, num_args, rv);
+ ~ExternCFunctionObjImpl() {
+ if (deleter_) deleter_(self_);
}
private:
+ static int32_t SafeCall(void* func, const TVMFFIAny* args, int32_t num_args,
TVMFFIAny* rv) {
+ ExternCFunctionObjImpl* self =
reinterpret_cast<ExternCFunctionObjImpl*>(func);
+ return self->safe_call_(self->self_, args, num_args, rv);
+ }
+
void* self_;
TVMFFISafeCallType safe_call_;
void (*deleter_)(void* self);
};
-/*!
- * \brief FunctionObj specialization that wraps an external function.
- */
-class ImportedFunctionObjImpl : public FunctionObj,
- public
RedirectCallToSafeCall<ImportedFunctionObjImpl> {
- public:
- using RedirectCallToSafeCall<ImportedFunctionObjImpl>::SafeCall;
-
- explicit ImportedFunctionObjImpl(ObjectPtr<Object> data) : data_(data) {
- this->call = RedirectCallToSafeCall<ImportedFunctionObjImpl>::Call;
- this->safe_call =
RedirectCallToSafeCall<ImportedFunctionObjImpl>::SafeCall;
- }
-
- TVM_FFI_INLINE int32_t RedirectSafeCall(const TVMFFIAny* args, int32_t
num_args,
- TVMFFIAny* rv) const {
- FunctionObj* func = const_cast<FunctionObj*>(static_cast<const
FunctionObj*>(data_.get()));
- return func->safe_call(func, args, num_args, rv);
- }
-
- private:
- ObjectPtr<Object> data_;
-};
-
// Helper class to set packed arguments
class PackedArgsSetter {
public:
@@ -353,31 +332,13 @@ class Function : public ObjectRef {
return FromPackedInternal(packed_call);
}
}
+
/*!
- * \brief Import a possibly externally defined function to this dll
- * \param other Function defined in another dynamic library.
+ * \brief Create ffi::Function from C style callbacks.
*
- * \note This function will redirect the call to safe_call in other.
- * It will try to detect if the function is already from the same DLL
- * and directly return the original function if so.
+ * self and deleter can be nullptr if the function does not need closure
support
+ * and corresponds to a raw function pointer.
*
- * \return The imported function.
- */
- static Function ImportFromExternDLL(Function other) {
- const FunctionObj* other_func = static_cast<const
FunctionObj*>(other.get());
- // the other function comes from the same dll, no action needed
- if (other_func->safe_call == &(FunctionObj::SafeCall) ||
- other_func->safe_call == &(details::ImportedFunctionObjImpl::SafeCall)
||
- other_func->safe_call == &(details::ExternCFunctionObjImpl::SafeCall))
{
- return other;
- }
- // the other function coems from a different library
- Function func;
- func.data_ =
make_object<details::ImportedFunctionObjImpl>(std::move(other.data_));
- return func;
- }
- /*!
- * \brief Create ffi::Function from a C style callbacks.
* \param self Resource handle to the function
* \param safe_call The safe_call definition in C.
* \param deleter The deleter to release the resource of self.
@@ -387,7 +348,11 @@ class Function : public ObjectRef {
void (*deleter)(void* self)) {
// the other function comes from a different library
Function func;
- func.data_ = make_object<details::ExternCFunctionObjImpl>(self, safe_call,
deleter);
+ if (self == nullptr && deleter == nullptr) {
+ func.data_ =
make_object<details::ExternCFunctionObjNullHandleImpl>(safe_call);
+ } else {
+ func.data_ = make_object<details::ExternCFunctionObjImpl>(self,
safe_call, deleter);
+ }
return func;
}
/*!
@@ -854,19 +819,19 @@ inline int32_t TypeKeyToIndex(std::string_view type_key) {
*
* \note The final symbol name is `__tvm_ffi_<ExportName>`.
*/
-#define TVM_FFI_DLL_EXPORT_TYPED_FUNC(ExportName, Function)
\
- extern "C" {
\
- TVM_FFI_DLL_EXPORT int __tvm_ffi_##ExportName(void* self, TVMFFIAny* args,
int32_t num_args, \
- TVMFFIAny* result) {
\
- TVM_FFI_SAFE_CALL_BEGIN();
\
- using FuncInfo = ::tvm::ffi::details::FunctionInfo<decltype(Function)>;
\
- static std::string name = #ExportName;
\
- ::tvm::ffi::details::unpack_call<typename FuncInfo::RetType>(
\
- std::make_index_sequence<FuncInfo::num_args>{}, &name, Function,
\
- reinterpret_cast<const ::tvm::ffi::AnyView*>(args), num_args,
\
- reinterpret_cast<::tvm::ffi::Any*>(result));
\
- TVM_FFI_SAFE_CALL_END();
\
- }
\
+#define TVM_FFI_DLL_EXPORT_TYPED_FUNC(ExportName, Function)
\
+ extern "C" {
\
+ TVM_FFI_DLL_EXPORT int __tvm_ffi_##ExportName(void* self, const TVMFFIAny*
args, \
+ int32_t num_args, TVMFFIAny*
result) { \
+ TVM_FFI_SAFE_CALL_BEGIN();
\
+ using FuncInfo = ::tvm::ffi::details::FunctionInfo<decltype(Function)>;
\
+ static std::string name = #ExportName;
\
+ ::tvm::ffi::details::unpack_call<typename FuncInfo::RetType>(
\
+ std::make_index_sequence<FuncInfo::num_args>{}, &name, Function,
\
+ reinterpret_cast<const ::tvm::ffi::AnyView*>(args), num_args,
\
+ reinterpret_cast<::tvm::ffi::Any*>(result));
\
+ TVM_FFI_SAFE_CALL_END();
\
+ }
\
}
} // namespace ffi
} // namespace tvm
diff --git a/pyproject.toml b/pyproject.toml
index 4a6a734..ff52507 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,7 +17,7 @@
[project]
name = "apache-tvm-ffi"
-version = "0.1.0b11"
+version = "0.1.0b12"
description = "tvm ffi"
authors = [{ name = "TVM FFI team" }]
diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py
index d88a35e..9a9ed83 100644
--- a/python/tvm_ffi/__init__.py
+++ b/python/tvm_ffi/__init__.py
@@ -17,7 +17,7 @@
"""TVM FFI Python package."""
# version
-__version__ = "0.1.0b11"
+__version__ = "0.1.0b12"
# order matters here so we need to skip isort here
# isort: skip_file
diff --git a/tests/cpp/test_function.cc b/tests/cpp/test_function.cc
index c3c484f..674bb75 100644
--- a/tests/cpp/test_function.cc
+++ b/tests/cpp/test_function.cc
@@ -236,4 +236,14 @@ TEST(Func, ObjectRefWithFallbackTraits) {
::tvm::ffi::Error);
}
+int testing_add1(int x) { return x + 1; }
+
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(testing_add1, testing_add1);
+
+TEST(Func, FromExternC) {
+ // this is the function abi convention
+ Function fadd1 = Function::FromExternC(nullptr, __tvm_ffi_testing_add1,
nullptr);
+ EXPECT_EQ(fadd1(1).cast<int>(), 2);
+}
+
} // namespace