This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 889bfb3 [Minor] use perfect forwarding for template types. (#266)
889bfb3 is described below
commit 889bfb360b5afa6f7b50774cba53e7e625ea1e8f
Author: DarkSharpness <[email protected]>
AuthorDate: Mon Nov 17 23:22:24 2025 +0800
[Minor] use perfect forwarding for template types. (#266)
We use pass-by-value-then-move for types that are cheap to move (e.g. `Function`,
`std::string`), and perfect forwarding for template callable parameters to avoid
any unnecessary copy/move construction along the forwarding chain.
---
include/tvm/ffi/function.h | 86 ++++++++++++++++++++++++++++------------------
1 file changed, 52 insertions(+), 34 deletions(-)
diff --git a/include/tvm/ffi/function.h b/include/tvm/ffi/function.h
index 8ef3f23..6db5435 100644
--- a/include/tvm/ffi/function.h
+++ b/include/tvm/ffi/function.h
@@ -31,6 +31,7 @@
#include <functional>
#include <string>
+#include <type_traits>
#include <utility>
#include <vector>
@@ -141,18 +142,30 @@ namespace details {
* \brief Derived object class for constructing FunctionObj backed by a
TCallable
*
* This is a helper class that implements the function call interface.
+ * Invariant: TCallable must not be a const-qualified or reference type.
*/
template <typename TCallable>
class FunctionObjImpl : public FunctionObj {
public:
- using TStorage = std::remove_cv_t<std::remove_reference_t<TCallable>>;
+ static_assert(std::is_same_v<TCallable,
std::remove_cv_t<std::remove_reference_t<TCallable>>>,
+ "TCallable of FunctionObjImpl cannot be const or reference
type");
+
/*! \brief The type of derived object class */
using TSelf = FunctionObjImpl<TCallable>;
+
/*!
* \brief Derived object class for constructing ffi::FunctionObj.
- * \param callable The type-erased callable object.
+ * \param callable The type-erased callable object (rvalue).
*/
- explicit FunctionObjImpl(TCallable callable) :
callable_(std::move(callable)) {
+ explicit FunctionObjImpl(TCallable&& callable) :
callable_(std::move(callable)) {
+ this->safe_call = SafeCall;
+ this->cpp_call = reinterpret_cast<void*>(CppCall);
+ }
+ /*!
+ * \brief Derived object class for constructing ffi::FunctionObj.
+ * \param callable The type-erased callable object (lvalue).
+ */
+ explicit FunctionObjImpl(const TCallable& callable) : callable_(callable) {
this->safe_call = SafeCall;
this->cpp_call = reinterpret_cast<void*>(CppCall);
}
@@ -174,7 +187,7 @@ class FunctionObjImpl : public FunctionObj {
}
/// \endcond
/*! \brief Type-erased filed for storing callable object*/
- mutable TStorage callable_;
+ mutable TCallable callable_;
};
/*!
@@ -305,17 +318,18 @@ class Function : public ObjectRef {
* \param packed_call The packed function signature
* \note legacy purpose, should change to Function::FromPacked for
mostfuture use.
*/
- template <typename TCallable>
- explicit Function(TCallable packed_call) {
- *this = FromPacked(packed_call);
+ template <typename TCallable,
+ typename =
std::enable_if_t<!std::is_same_v<std::decay_t<TCallable>, Function>>>
+ explicit Function(TCallable&& packed_call) {
+ *this = FromPacked(std::forward<TCallable>(packed_call));
}
/*!
* \brief Constructing a packed function from a callable type
* whose signature is consistent with `ffi::Function`
* \param packed_call The packed function signature
*/
- template <typename TCallable> // //
NOLINTNEXTLINE(performance-unnecessary-value-param)
- static Function FromPacked(TCallable packed_call) {
+ template <typename TCallable>
+ static Function FromPacked(TCallable&& packed_call) {
static_assert(
std::is_convertible_v<TCallable, std::function<void(const AnyView*,
int32_t, Any*)>> ||
std::is_convertible_v<TCallable, std::function<void(PackedArgs
args, Any*)>>,
@@ -323,12 +337,12 @@ class Function : public ObjectRef {
"format");
if constexpr (std::is_convertible_v<TCallable,
std::function<void(PackedArgs args, Any*)>>) {
return FromPackedInternal(
- [packed_call](const AnyView* args, int32_t num_args, Any* rv)
mutable -> void {
- PackedArgs args_pack(args, num_args);
- packed_call(args_pack, rv);
+ [packed_call = std::forward<TCallable>(packed_call)](
+ const AnyView* args, int32_t num_args, Any* rv) mutable -> void {
+ packed_call(PackedArgs{args, num_args}, rv);
});
} else {
- return FromPackedInternal(packed_call);
+ return FromPackedInternal(std::forward<TCallable>(packed_call));
}
}
@@ -487,10 +501,11 @@ class Function : public ObjectRef {
* \param callable the internal container of packed function.
*/
template <typename TCallable>
- static Function FromTyped(TCallable callable) {
- using FuncInfo = details::FunctionInfo<TCallable>;
- auto call_packed = [callable = std::move(callable)](const AnyView* args,
int32_t num_args,
- Any* rv) mutable ->
void {
+ static Function FromTyped(TCallable&& callable) {
+ using FuncInfo = details::FunctionInfo<std::decay_t<TCallable>>;
+ // Callable is always captured by value here to avoid a possible dangling reference
+ auto call_packed = [callable = std::forward<TCallable>(callable)](
+ const AnyView* args, int32_t num_args, Any* rv)
mutable -> void {
details::unpack_call<typename FuncInfo::RetType>(
std::make_index_sequence<FuncInfo::num_args>{}, nullptr, callable,
args, num_args, rv);
};
@@ -503,9 +518,10 @@ class Function : public ObjectRef {
* \param name optional name attacked to the function.
*/
template <typename TCallable>
- static Function FromTyped(TCallable callable, std::string name) {
- using FuncInfo = details::FunctionInfo<TCallable>;
- auto call_packed = [callable = std::move(callable), name =
std::move(name)](
+ static Function FromTyped(TCallable&& callable, std::string name) {
+ using FuncInfo = details::FunctionInfo<std::decay_t<TCallable>>;
+ // Callable is always captured by value here to avoid a possible dangling reference
+ auto call_packed = [callable = std::forward<TCallable>(callable), name =
std::move(name)](
const AnyView* args, int32_t num_args, Any* rv)
mutable -> void {
details::unpack_call<typename FuncInfo::RetType>(
std::make_index_sequence<FuncInfo::num_args>{}, &name, callable,
args, num_args, rv);
@@ -611,11 +627,11 @@ class Function : public ObjectRef {
* \param packed_call The packed function signature
*/
template <typename TCallable>
- static Function FromPackedInternal(TCallable packed_call) {
- using ObjType = typename details::FunctionObjImpl<TCallable>;
+ static Function FromPackedInternal(TCallable&& packed_call) {
+ // We must make TCallable a value type (decay_t) that can hold the
callable object
+ using ObjType = typename details::FunctionObjImpl<std::decay_t<TCallable>>;
Function func;
- func.data_ = make_object<ObjType>(
- std::forward<TCallable>(packed_call)); //
NOLINT(bugprone-chained-comparison)
+ func.data_ = make_object<ObjType>(std::forward<TCallable>(packed_call));
return func;
}
};
@@ -671,7 +687,7 @@ class TypedFunction<R(Args...)> {
* \brief constructor from a function
* \param packed The function
*/
- TypedFunction(Function packed) : packed_(packed) {} // NOLINT(*)
+ TypedFunction(Function packed) : packed_(std::move(packed)) {} // NOLINT(*)
/*!
* \brief construct from a lambda function with the same signature.
*
@@ -690,8 +706,8 @@ class TypedFunction<R(Args...)> {
*/
template <typename FLambda,
typename = std::enable_if_t<std::is_convertible_v<FLambda,
std::function<R(Args...)>>>>
- TypedFunction(FLambda typed_lambda, std::string name) { // NOLINT(*)
- packed_ = Function::FromTyped(typed_lambda, name);
+ TypedFunction(FLambda&& typed_lambda, std::string name) {
+ packed_ = Function::FromTyped(std::forward<FLambda>(typed_lambda),
std::move(name));
}
/*!
* \brief construct from a lambda function with the same signature.
@@ -712,9 +728,10 @@ class TypedFunction<R(Args...)> {
* \tparam FLambda the type of the lambda function.
*/
template <typename FLambda,
- typename = std::enable_if_t<std::is_convertible_v<FLambda,
std::function<R(Args...)>>>>
- TypedFunction(const FLambda& typed_lambda) { // NOLINT(*)
- packed_ = Function::FromTyped(typed_lambda);
+ typename = std::enable_if_t<std::is_convertible_v<FLambda,
std::function<R(Args...)>> &&
+ !std::is_same_v<std::decay_t<FLambda>,
TSelf>>>
+ TypedFunction(FLambda&& typed_lambda) { //
NOLINT(google-explicit-constructor)
+ packed_ = Function::FromTyped(std::forward<FLambda>(typed_lambda));
}
/*!
* \brief copy assignment operator from typed lambda
@@ -733,9 +750,10 @@ class TypedFunction<R(Args...)> {
* \returns reference to self.
*/
template <typename FLambda,
- typename = std::enable_if_t<std::is_convertible_v<FLambda,
std::function<R(Args...)>>>>
- TSelf& operator=(FLambda typed_lambda) { // NOLINT(*)
- packed_ = Function::FromTyped(typed_lambda);
+ typename = std::enable_if_t<std::is_convertible_v<FLambda,
std::function<R(Args...)>> &&
+ !std::is_same_v<std::decay_t<FLambda>,
TSelf>>>
+ TSelf& operator=(FLambda&& typed_lambda) {
+ packed_ = Function::FromTyped(std::forward<FLambda>(typed_lambda));
return *this;
}
/*!
@@ -752,7 +770,7 @@ class TypedFunction<R(Args...)> {
* \param args The arguments
* \returns The return value.
*/
- TVM_FFI_INLINE R operator()(Args... args) const { //
NOLINT(performance-unnecessary-value-param)
+ TVM_FFI_INLINE R operator()(Args... args) const {
if constexpr (std::is_same_v<R, void>) {
packed_(std::forward<Args>(args)...);
} else {