This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 889bfb3  [Minor] use perfect forwarding for template types. (#266)
889bfb3 is described below

commit 889bfb360b5afa6f7b50774cba53e7e625ea1e8f
Author: DarkSharpness <[email protected]>
AuthorDate: Mon Nov 17 23:22:24 2025 +0800

    [Minor] use perfect forwarding for template types. (#266)
    
    We pass by value and then move for types that are cheap to move (e.g.
    `Function`, `std::string`), and use perfect forwarding for the other
    template parameters to avoid unnecessary copy/move constructions along
    the forwarding chain.
---
 include/tvm/ffi/function.h | 86 ++++++++++++++++++++++++++++------------------
 1 file changed, 52 insertions(+), 34 deletions(-)

diff --git a/include/tvm/ffi/function.h b/include/tvm/ffi/function.h
index 8ef3f23..6db5435 100644
--- a/include/tvm/ffi/function.h
+++ b/include/tvm/ffi/function.h
@@ -31,6 +31,7 @@
 
 #include <functional>
 #include <string>
+#include <type_traits>
 #include <utility>
 #include <vector>
 
@@ -141,18 +142,30 @@ namespace details {
  * \brief Derived object class for constructing FunctionObj backed by a 
TCallable
  *
  * This is a helper class that implements the function call interface.
+ * Invariant: TCallable must not be a const-qualified or reference type.
  */
 template <typename TCallable>
 class FunctionObjImpl : public FunctionObj {
  public:
-  using TStorage = std::remove_cv_t<std::remove_reference_t<TCallable>>;
+  static_assert(std::is_same_v<TCallable, 
std::remove_cv_t<std::remove_reference_t<TCallable>>>,
+                "TCallable of FunctionObjImpl cannot be const or reference 
type");
+
   /*! \brief The type of derived object class */
   using TSelf = FunctionObjImpl<TCallable>;
+
   /*!
    * \brief Derived object class for constructing ffi::FunctionObj.
-   * \param callable The type-erased callable object.
+   * \param callable The type-erased callable object (rvalue).
    */
-  explicit FunctionObjImpl(TCallable callable) : 
callable_(std::move(callable)) {
+  explicit FunctionObjImpl(TCallable&& callable) : 
callable_(std::move(callable)) {
+    this->safe_call = SafeCall;
+    this->cpp_call = reinterpret_cast<void*>(CppCall);
+  }
+  /*!
+   * \brief Derived object class for constructing ffi::FunctionObj.
+   * \param callable The type-erased callable object (lvalue).
+   */
+  explicit FunctionObjImpl(const TCallable& callable) : callable_(callable) {
     this->safe_call = SafeCall;
     this->cpp_call = reinterpret_cast<void*>(CppCall);
   }
@@ -174,7 +187,7 @@ class FunctionObjImpl : public FunctionObj {
   }
   /// \endcond
   /*! \brief Type-erased filed for storing callable object*/
-  mutable TStorage callable_;
+  mutable TCallable callable_;
 };
 
 /*!
@@ -305,17 +318,18 @@ class Function : public ObjectRef {
    * \param packed_call The packed function signature
    * \note legacy purpose, should change to Function::FromPacked for 
mostfuture use.
    */
-  template <typename TCallable>
-  explicit Function(TCallable packed_call) {
-    *this = FromPacked(packed_call);
+  template <typename TCallable,
+            typename = 
std::enable_if_t<!std::is_same_v<std::decay_t<TCallable>, Function>>>
+  explicit Function(TCallable&& packed_call) {
+    *this = FromPacked(std::forward<TCallable>(packed_call));
   }
   /*!
    * \brief Constructing a packed function from a callable type
    *        whose signature is consistent with `ffi::Function`
    * \param packed_call The packed function signature
    */
-  template <typename TCallable>  // // 
NOLINTNEXTLINE(performance-unnecessary-value-param)
-  static Function FromPacked(TCallable packed_call) {
+  template <typename TCallable>
+  static Function FromPacked(TCallable&& packed_call) {
     static_assert(
         std::is_convertible_v<TCallable, std::function<void(const AnyView*, 
int32_t, Any*)>> ||
             std::is_convertible_v<TCallable, std::function<void(PackedArgs 
args, Any*)>>,
@@ -323,12 +337,12 @@ class Function : public ObjectRef {
         "format");
     if constexpr (std::is_convertible_v<TCallable, 
std::function<void(PackedArgs args, Any*)>>) {
       return FromPackedInternal(
-          [packed_call](const AnyView* args, int32_t num_args, Any* rv) 
mutable -> void {
-            PackedArgs args_pack(args, num_args);
-            packed_call(args_pack, rv);
+          [packed_call = std::forward<TCallable>(packed_call)](
+              const AnyView* args, int32_t num_args, Any* rv) mutable -> void {
+            packed_call(PackedArgs{args, num_args}, rv);
           });
     } else {
-      return FromPackedInternal(packed_call);
+      return FromPackedInternal(std::forward<TCallable>(packed_call));
     }
   }
 
@@ -487,10 +501,11 @@ class Function : public ObjectRef {
    * \param callable the internal container of packed function.
    */
   template <typename TCallable>
-  static Function FromTyped(TCallable callable) {
-    using FuncInfo = details::FunctionInfo<TCallable>;
-    auto call_packed = [callable = std::move(callable)](const AnyView* args, 
int32_t num_args,
-                                                        Any* rv) mutable -> 
void {
+  static Function FromTyped(TCallable&& callable) {
+    using FuncInfo = details::FunctionInfo<std::decay_t<TCallable>>;
+    // Callable is always captured by value here to avoid possible dangling 
reference
+    auto call_packed = [callable = std::forward<TCallable>(callable)](
+                           const AnyView* args, int32_t num_args, Any* rv) 
mutable -> void {
       details::unpack_call<typename FuncInfo::RetType>(
           std::make_index_sequence<FuncInfo::num_args>{}, nullptr, callable, 
args, num_args, rv);
     };
@@ -503,9 +518,10 @@ class Function : public ObjectRef {
    * \param name optional name attacked to the function.
    */
   template <typename TCallable>
-  static Function FromTyped(TCallable callable, std::string name) {
-    using FuncInfo = details::FunctionInfo<TCallable>;
-    auto call_packed = [callable = std::move(callable), name = 
std::move(name)](
+  static Function FromTyped(TCallable&& callable, std::string name) {
+    using FuncInfo = details::FunctionInfo<std::decay_t<TCallable>>;
+    // Callable is always captured by value here to avoid possible dangling 
reference
+    auto call_packed = [callable = std::forward<TCallable>(callable), name = 
std::move(name)](
                            const AnyView* args, int32_t num_args, Any* rv) 
mutable -> void {
       details::unpack_call<typename FuncInfo::RetType>(
           std::make_index_sequence<FuncInfo::num_args>{}, &name, callable, 
args, num_args, rv);
@@ -611,11 +627,11 @@ class Function : public ObjectRef {
    * \param packed_call The packed function signature
    */
   template <typename TCallable>
-  static Function FromPackedInternal(TCallable packed_call) {
-    using ObjType = typename details::FunctionObjImpl<TCallable>;
+  static Function FromPackedInternal(TCallable&& packed_call) {
+    // We must make TCallable a value type (decay_t) that can hold the 
callable object
+    using ObjType = typename details::FunctionObjImpl<std::decay_t<TCallable>>;
     Function func;
-    func.data_ = make_object<ObjType>(
-        std::forward<TCallable>(packed_call));  // 
NOLINT(bugprone-chained-comparison)
+    func.data_ = make_object<ObjType>(std::forward<TCallable>(packed_call));
     return func;
   }
 };
@@ -671,7 +687,7 @@ class TypedFunction<R(Args...)> {
    * \brief constructor from a function
    * \param packed The function
    */
-  TypedFunction(Function packed) : packed_(packed) {}  // NOLINT(*)
+  TypedFunction(Function packed) : packed_(std::move(packed)) {}  // NOLINT(*)
   /*!
    * \brief construct from a lambda function with the same signature.
    *
@@ -690,8 +706,8 @@ class TypedFunction<R(Args...)> {
    */
   template <typename FLambda,
             typename = std::enable_if_t<std::is_convertible_v<FLambda, 
std::function<R(Args...)>>>>
-  TypedFunction(FLambda typed_lambda, std::string name) {  // NOLINT(*)
-    packed_ = Function::FromTyped(typed_lambda, name);
+  TypedFunction(FLambda&& typed_lambda, std::string name) {
+    packed_ = Function::FromTyped(std::forward<FLambda>(typed_lambda), 
std::move(name));
   }
   /*!
    * \brief construct from a lambda function with the same signature.
@@ -712,9 +728,10 @@ class TypedFunction<R(Args...)> {
    * \tparam FLambda the type of the lambda function.
    */
   template <typename FLambda,
-            typename = std::enable_if_t<std::is_convertible_v<FLambda, 
std::function<R(Args...)>>>>
-  TypedFunction(const FLambda& typed_lambda) {  // NOLINT(*)
-    packed_ = Function::FromTyped(typed_lambda);
+            typename = std::enable_if_t<std::is_convertible_v<FLambda, 
std::function<R(Args...)>> &&
+                                        !std::is_same_v<std::decay_t<FLambda>, 
TSelf>>>
+  TypedFunction(FLambda&& typed_lambda) {  // 
NOLINT(google-explicit-constructor)
+    packed_ = Function::FromTyped(std::forward<FLambda>(typed_lambda));
   }
   /*!
    * \brief copy assignment operator from typed lambda
@@ -733,9 +750,10 @@ class TypedFunction<R(Args...)> {
    * \returns reference to self.
    */
   template <typename FLambda,
-            typename = std::enable_if_t<std::is_convertible_v<FLambda, 
std::function<R(Args...)>>>>
-  TSelf& operator=(FLambda typed_lambda) {  // NOLINT(*)
-    packed_ = Function::FromTyped(typed_lambda);
+            typename = std::enable_if_t<std::is_convertible_v<FLambda, 
std::function<R(Args...)>> &&
+                                        !std::is_same_v<std::decay_t<FLambda>, 
TSelf>>>
+  TSelf& operator=(FLambda&& typed_lambda) {
+    packed_ = Function::FromTyped(std::forward<FLambda>(typed_lambda));
     return *this;
   }
   /*!
@@ -752,7 +770,7 @@ class TypedFunction<R(Args...)> {
    * \param args The arguments
    * \returns The return value.
    */
-  TVM_FFI_INLINE R operator()(Args... args) const {  // 
NOLINT(performance-unnecessary-value-param)
+  TVM_FFI_INLINE R operator()(Args... args) const {
     if constexpr (std::is_same_v<R, void>) {
       packed_(std::forward<Args>(args)...);
     } else {

Reply via email to