This is an automated email from the ASF dual-hosted git repository.

yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 4fe8b2b  [ABI] Further clarify Function ABI in cpp (#71)
4fe8b2b is described below

commit 4fe8b2b79dfeb469b2499acecb3e10038ddcee0f
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Sep 29 14:32:11 2025 -0400

    [ABI] Further clarify Function ABI in cpp (#71)
---
 include/tvm/ffi/c_api.h    |  13 ++++
 include/tvm/ffi/function.h | 165 ++++++++++++++++++---------------------------
 pyproject.toml             |   2 +-
 python/tvm_ffi/__init__.py |   2 +-
 tests/cpp/test_function.cc |  10 +++
 5 files changed, 90 insertions(+), 102 deletions(-)

diff --git a/include/tvm/ffi/c_api.h b/include/tvm/ffi/c_api.h
index 0cac1f7..c688aee 100644
--- a/include/tvm/ffi/c_api.h
+++ b/include/tvm/ffi/c_api.h
@@ -434,6 +434,19 @@ typedef int (*TVMFFISafeCallType)(void* handle, const 
TVMFFIAny* args, int32_t n
 typedef struct {
   /*! \brief A C API compatible call with exception catching. */
   TVMFFISafeCallType safe_call;
+  /*!
+   * \brief A function pointer to an underlying cpp call.
+   *
+   * The signature is the same as TVMFFISafeCallType except the return type is void,
+   * and the function throws exceptions directly instead of returning an error code.
+   * We use void* here to avoid depending on a C++ compiler.
+   *
+   * This pointer should be set to NULL for functions that are not originally created in C++.
+   *
+   * \note The caller must assume the same C++ exception-handling ABI when using this pointer.
+   *       When used across FFI boundaries, always use safe_call.
+   */
+  void* cpp_call;
 } TVMFFIFunctionCell;
 
 /*!
diff --git a/include/tvm/ffi/function.h b/include/tvm/ffi/function.h
index 920856d..ba3ec0d 100644
--- a/include/tvm/ffi/function.h
+++ b/include/tvm/ffi/function.h
@@ -102,9 +102,8 @@ class FunctionObj : public Object, public 
TVMFFIFunctionCell {
  public:
   /*! \brief Typedef for C++ style calling signature that comes with exception 
propagation */
   typedef void (*FCall)(const FunctionObj*, const AnyView*, int32_t, Any*);
+  using TVMFFIFunctionCell::cpp_call;
   using TVMFFIFunctionCell::safe_call;
-  /*! \brief A C++ style call implementation, with exception propagation in 
C++ style. */
-  FCall call;
   /*!
    * \brief Call the function in packed format.
    * \param args The arguments
@@ -112,7 +111,11 @@ class FunctionObj : public Object, public 
TVMFFIFunctionCell {
    * \param result The return value.
    */
   TVM_FFI_INLINE void CallPacked(const AnyView* args, int32_t num_args, Any* 
result) const {
-    this->call(this, args, num_args, result);
+    // If cpp_call is set, use it to call the function; otherwise, redirect to safe_call.
+    // Use a conditional expression here so the select is branchless.
+    FCall call_ptr =
+        this->cpp_call ? reinterpret_cast<FCall>(this->cpp_call) : 
CppCallDedirectToSafeCall;
+    (*call_ptr)(this, args, num_args, result);
   }
   /// \cond Doxygen_Suppress
   static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIFunction;
@@ -122,18 +125,15 @@ class FunctionObj : public Object, public 
TVMFFIFunctionCell {
  protected:
   /*! \brief Make default constructor protected. */
   FunctionObj() {}
-  /// \cond Doxygen_Suppress
-  // Implementing safe call style
-  static int SafeCall(void* func, const TVMFFIAny* args, int32_t num_args, 
TVMFFIAny* result) {
-    TVM_FFI_SAFE_CALL_BEGIN();
-    TVM_FFI_ICHECK_LT(result->type_index, TypeIndex::kTVMFFIStaticObjectBegin);
-    FunctionObj* self = static_cast<FunctionObj*>(func);
-    self->call(self, reinterpret_cast<const AnyView*>(args), num_args,
-               reinterpret_cast<Any*>(result));
-    TVM_FFI_SAFE_CALL_END();
-  }
-  /// \endcond
   friend class Function;
+
+ private:
+  static void CppCallDedirectToSafeCall(const FunctionObj* func, const 
AnyView* args,
+                                        int32_t num_args, Any* rv) {
+    FunctionObj* self = 
static_cast<FunctionObj*>(const_cast<FunctionObj*>(func));
+    TVM_FFI_CHECK_SAFE_CALL(self->safe_call(self, reinterpret_cast<const 
TVMFFIAny*>(args),
+                                            num_args, 
reinterpret_cast<TVMFFIAny*>(rv)));
+  }
 };
 
 namespace details {
@@ -154,87 +154,66 @@ class FunctionObjImpl : public FunctionObj {
    */
   explicit FunctionObjImpl(TCallable callable) : callable_(callable) {
     this->safe_call = SafeCall;
-    this->call = Call;
+    this->cpp_call = reinterpret_cast<void*>(CppCall);
   }
 
  private:
   // implementation of call
-  static void Call(const FunctionObj* func, const AnyView* args, int32_t 
num_args, Any* result) {
+  static void CppCall(const FunctionObj* func, const AnyView* args, int32_t 
num_args, Any* result) {
     (static_cast<const TSelf*>(func))->callable_(args, num_args, result);
   }
-
+  /// \cond Doxygen_Suppress
+  // Implementing safe call style
+  static int SafeCall(void* func, const TVMFFIAny* args, int32_t num_args, 
TVMFFIAny* result) {
+    TVM_FFI_SAFE_CALL_BEGIN();
+    TVM_FFI_ICHECK_LT(result->type_index, TypeIndex::kTVMFFIStaticObjectBegin);
+    FunctionObj* self = static_cast<FunctionObj*>(func);
+    reinterpret_cast<FCall>(self->cpp_call)(self, reinterpret_cast<const 
AnyView*>(args), num_args,
+                                            reinterpret_cast<Any*>(result));
+    TVM_FFI_SAFE_CALL_END();
+  }
+  /// \endcond
   /*! \brief Type-erased filed for storing callable object*/
   mutable TStorage callable_;
 };
 
 /*!
- * \brief Base class to provide a common implementation to redirect call to 
safecall
- * \tparam Derived The derived class in CRTP-idiom
+ * \brief FunctionObj specialization for raw C-style callbacks whose handle and deleter are null.
  */
-template <typename Derived>
-struct RedirectCallToSafeCall {
-  static void Call(const FunctionObj* func, const AnyView* args, int32_t 
num_args, Any* rv) {
-    Derived* self = static_cast<Derived*>(const_cast<FunctionObj*>(func));
-    TVM_FFI_CHECK_SAFE_CALL(self->RedirectSafeCall(reinterpret_cast<const 
TVMFFIAny*>(args),
-                                                   num_args, 
reinterpret_cast<TVMFFIAny*>(rv)));
-  }
-
-  static int32_t SafeCall(void* func, const TVMFFIAny* args, int32_t num_args, 
TVMFFIAny* rv) {
-    Derived* self = reinterpret_cast<Derived*>(func);
-    return self->RedirectSafeCall(args, num_args, rv);
+class ExternCFunctionObjNullHandleImpl : public FunctionObj {
+ public:
+  explicit ExternCFunctionObjNullHandleImpl(TVMFFISafeCallType safe_call) {
+    this->safe_call = safe_call;
+    this->cpp_call = nullptr;
   }
 };
 
 /*!
  * \brief FunctionObj specialization that leverages C-style callback 
definitions.
  */
-class ExternCFunctionObjImpl : public FunctionObj,
-                               public 
RedirectCallToSafeCall<ExternCFunctionObjImpl> {
+class ExternCFunctionObjImpl : public FunctionObj {
  public:
-  using RedirectCallToSafeCall<ExternCFunctionObjImpl>::SafeCall;
-
   ExternCFunctionObjImpl(void* self, TVMFFISafeCallType safe_call, void 
(*deleter)(void* self))
       : self_(self), safe_call_(safe_call), deleter_(deleter) {
-    this->call = RedirectCallToSafeCall<ExternCFunctionObjImpl>::Call;
-    this->safe_call = RedirectCallToSafeCall<ExternCFunctionObjImpl>::SafeCall;
+    this->safe_call = SafeCall;
+    this->cpp_call = nullptr;
   }
 
-  ~ExternCFunctionObjImpl() { deleter_(self_); }
-
-  TVM_FFI_INLINE int32_t RedirectSafeCall(const TVMFFIAny* args, int32_t 
num_args,
-                                          TVMFFIAny* rv) const {
-    return safe_call_(self_, args, num_args, rv);
+  ~ExternCFunctionObjImpl() {
+    if (deleter_) deleter_(self_);
   }
 
  private:
+  static int32_t SafeCall(void* func, const TVMFFIAny* args, int32_t num_args, 
TVMFFIAny* rv) {
+    ExternCFunctionObjImpl* self = 
reinterpret_cast<ExternCFunctionObjImpl*>(func);
+    return self->safe_call_(self->self_, args, num_args, rv);
+  }
+
   void* self_;
   TVMFFISafeCallType safe_call_;
   void (*deleter_)(void* self);
 };
 
-/*!
- * \brief FunctionObj specialization that wraps an external function.
- */
-class ImportedFunctionObjImpl : public FunctionObj,
-                                public 
RedirectCallToSafeCall<ImportedFunctionObjImpl> {
- public:
-  using RedirectCallToSafeCall<ImportedFunctionObjImpl>::SafeCall;
-
-  explicit ImportedFunctionObjImpl(ObjectPtr<Object> data) : data_(data) {
-    this->call = RedirectCallToSafeCall<ImportedFunctionObjImpl>::Call;
-    this->safe_call = 
RedirectCallToSafeCall<ImportedFunctionObjImpl>::SafeCall;
-  }
-
-  TVM_FFI_INLINE int32_t RedirectSafeCall(const TVMFFIAny* args, int32_t 
num_args,
-                                          TVMFFIAny* rv) const {
-    FunctionObj* func = const_cast<FunctionObj*>(static_cast<const 
FunctionObj*>(data_.get()));
-    return func->safe_call(func, args, num_args, rv);
-  }
-
- private:
-  ObjectPtr<Object> data_;
-};
-
 // Helper class to set packed arguments
 class PackedArgsSetter {
  public:
@@ -353,31 +332,13 @@ class Function : public ObjectRef {
       return FromPackedInternal(packed_call);
     }
   }
+
   /*!
-   * \brief Import a possibly externally defined function to this dll
-   * \param other Function defined in another dynamic library.
+   * \brief Create ffi::Function from C-style callbacks.
    *
-   * \note This function will redirect the call to safe_call in other.
-   *  It will try to detect if the function is already from the same DLL
-   *  and directly return the original function if so.
+   * self and deleter can be nullptr if the function does not need closure support
+   * and corresponds to a raw function pointer.
    *
-   * \return The imported function.
-   */
-  static Function ImportFromExternDLL(Function other) {
-    const FunctionObj* other_func = static_cast<const 
FunctionObj*>(other.get());
-    // the other function comes from the same dll, no action needed
-    if (other_func->safe_call == &(FunctionObj::SafeCall) ||
-        other_func->safe_call == &(details::ImportedFunctionObjImpl::SafeCall) 
||
-        other_func->safe_call == &(details::ExternCFunctionObjImpl::SafeCall)) 
{
-      return other;
-    }
-    // the other function coems from a different library
-    Function func;
-    func.data_ = 
make_object<details::ImportedFunctionObjImpl>(std::move(other.data_));
-    return func;
-  }
-  /*!
-   * \brief Create ffi::Function from a C style callbacks.
    * \param self Resource handle to the function
    * \param safe_call The safe_call definition in C.
    * \param deleter The deleter to release the resource of self.
@@ -387,7 +348,11 @@ class Function : public ObjectRef {
                               void (*deleter)(void* self)) {
     // the other function coems from a different library
     Function func;
-    func.data_ = make_object<details::ExternCFunctionObjImpl>(self, safe_call, 
deleter);
+    if (self == nullptr && deleter == nullptr) {
+      func.data_ = 
make_object<details::ExternCFunctionObjNullHandleImpl>(safe_call);
+    } else {
+      func.data_ = make_object<details::ExternCFunctionObjImpl>(self, 
safe_call, deleter);
+    }
     return func;
   }
   /*!
@@ -854,19 +819,19 @@ inline int32_t TypeKeyToIndex(std::string_view type_key) {
  *
  * \note The final symbol name is `__tvm_ffi_<ExportName>`.
  */
-#define TVM_FFI_DLL_EXPORT_TYPED_FUNC(ExportName, Function)                    
                \
-  extern "C" {                                                                 
                \
-  TVM_FFI_DLL_EXPORT int __tvm_ffi_##ExportName(void* self, TVMFFIAny* args, 
int32_t num_args, \
-                                                TVMFFIAny* result) {           
                \
-    TVM_FFI_SAFE_CALL_BEGIN();                                                 
                \
-    using FuncInfo = ::tvm::ffi::details::FunctionInfo<decltype(Function)>;    
                \
-    static std::string name = #ExportName;                                     
                \
-    ::tvm::ffi::details::unpack_call<typename FuncInfo::RetType>(              
                \
-        std::make_index_sequence<FuncInfo::num_args>{}, &name, Function,       
                \
-        reinterpret_cast<const ::tvm::ffi::AnyView*>(args), num_args,          
                \
-        reinterpret_cast<::tvm::ffi::Any*>(result));                           
                \
-    TVM_FFI_SAFE_CALL_END();                                                   
                \
-  }                                                                            
                \
+#define TVM_FFI_DLL_EXPORT_TYPED_FUNC(ExportName, Function)                    
        \
+  extern "C" {                                                                 
        \
+  TVM_FFI_DLL_EXPORT int __tvm_ffi_##ExportName(void* self, const TVMFFIAny* 
args,     \
+                                                int32_t num_args, TVMFFIAny* 
result) { \
+    TVM_FFI_SAFE_CALL_BEGIN();                                                 
        \
+    using FuncInfo = ::tvm::ffi::details::FunctionInfo<decltype(Function)>;    
        \
+    static std::string name = #ExportName;                                     
        \
+    ::tvm::ffi::details::unpack_call<typename FuncInfo::RetType>(              
        \
+        std::make_index_sequence<FuncInfo::num_args>{}, &name, Function,       
        \
+        reinterpret_cast<const ::tvm::ffi::AnyView*>(args), num_args,          
        \
+        reinterpret_cast<::tvm::ffi::Any*>(result));                           
        \
+    TVM_FFI_SAFE_CALL_END();                                                   
        \
+  }                                                                            
        \
   }
 }  // namespace ffi
 }  // namespace tvm
diff --git a/pyproject.toml b/pyproject.toml
index 4a6a734..ff52507 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,7 +17,7 @@
 
 [project]
 name = "apache-tvm-ffi"
-version = "0.1.0b11"
+version = "0.1.0b12"
 description = "tvm ffi"
 
 authors = [{ name = "TVM FFI team" }]
diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py
index d88a35e..9a9ed83 100644
--- a/python/tvm_ffi/__init__.py
+++ b/python/tvm_ffi/__init__.py
@@ -17,7 +17,7 @@
 """TVM FFI Python package."""
 
 # version
-__version__ = "0.1.0b11"
+__version__ = "0.1.0b12"
 
 # order matters here so we need to skip isort here
 # isort: skip_file
diff --git a/tests/cpp/test_function.cc b/tests/cpp/test_function.cc
index c3c484f..674bb75 100644
--- a/tests/cpp/test_function.cc
+++ b/tests/cpp/test_function.cc
@@ -236,4 +236,14 @@ TEST(Func, ObjectRefWithFallbackTraits) {
       ::tvm::ffi::Error);
 }
 
+int testing_add1(int x) { return x + 1; }
+
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(testing_add1, testing_add1);
+
+TEST(Func, FromExternC) {
+  // this is the function abi convention
+  Function fadd1 = Function::FromExternC(nullptr, __tvm_ffi_testing_add1, 
nullptr);
+  EXPECT_EQ(fadd1(1).cast<int>(), 2);
+}
+
 }  // namespace

Reply via email to