This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 3dd7a81 [REFACTOR][FEAT] Introduce generic value protocol (#312)
3dd7a81 is described below
commit 3dd7a8173363bdf79806610818121e83e99b3b56
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu Dec 4 18:57:46 2025 -0500
[REFACTOR][FEAT] Introduce generic value protocol (#312)
This introduces a generic protocol that gives classes the ability
to declare how they convert to tvm-ffi compatible values.
Also did a round of refactoring to consolidate classes in the Python helpers.
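For illustration, a minimal sketch of the protocol from the Python side,
mirroring the new tests added in this commit (ValueProtocol is just an
example class; testing.echo is the echo helper those tests use):

    from typing import Any

    import tvm_ffi

    class ValueProtocol:
        """Opt into FFI conversion by defining __tvm_ffi_value__."""

        def __init__(self, value: Any) -> None:
            self.value = value

        def __tvm_ffi_value__(self) -> Any:
            # Return any tvm-ffi compatible value; the dispatcher converts
            # the result recursively, so nested protocol objects also work.
            return self.value

    fecho = tvm_ffi.get_global_func("testing.echo")
    assert fecho(ValueProtocol(10)) == 10
    assert tuple(fecho(ValueProtocol([1, 2, ValueProtocol(3)]))) == (1, 2, 3)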
---
python/tvm_ffi/cython/base.pxi | 8 +
python/tvm_ffi/cython/function.pxi | 19 +++
python/tvm_ffi/cython/tensor.pxi | 2 +-
python/tvm_ffi/cython/tvm_ffi_python_helpers.h | 228 +++++++++++++++----------
tests/python/test_function.py | 19 +++
tests/python/test_stream.py | 3 +
tests/python/test_tensor.py | 23 +++
tests/scripts/benchmark_dlpack.py | 44 +++--
8 files changed, 246 insertions(+), 100 deletions(-)
diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index a24a5d8..933bc86 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -364,10 +364,18 @@ cdef extern from "tvm_ffi_python_helpers.h":
int* c_api_ret_code
) except -1
+ int TVMFFIPySetArgumentGenericDispatcher(
+ TVMFFIPyArgSetterFactory setter_factory,
+ TVMFFIPyCallContext* ctx,
+ PyObject* py_arg,
+ TVMFFIAny* out
+ ) except -1
+
size_t TVMFFIPyGetDispatchMapSize() noexcept
    void TVMFFIPyPushTempFFIObject(TVMFFIPyCallContext* ctx, TVMFFIObjectHandle arg) noexcept
    void TVMFFIPyPushTempPyObject(TVMFFIPyCallContext* ctx, PyObject* arg) noexcept
+ void TVMFFIPyPushExtraTempPyObject(TVMFFIPyCallContext* ctx, PyObject* arg)
# the predefined setters for common POD types
    int TVMFFIPyArgSetterFloat_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, PyObject* arg, TVMFFIAny* out) except -1
    int TVMFFIPyArgSetterInt_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, PyObject* arg, TVMFFIAny* out) except -1
diff --git a/python/tvm_ffi/cython/function.pxi b/python/tvm_ffi/cython/function.pxi
index 189a6fc..29af699 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -674,6 +674,22 @@ cdef int TVMFFIPyArgSetterFloatProtocol_(
return 0
+cdef int TVMFFIPyArgSetterFFIValueProtocol_(
+ TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
+ PyObject* py_arg, TVMFFIAny* out
+) except -1:
+ """Setter for class with __tvm_ffi_value__() method"""
+ cdef object arg = <object>py_arg
+ cdef object ffi_value_py_obj = arg.__tvm_ffi_value__()
+ cdef PyObject* ffi_value_py_obj_ptr = <PyObject*>ffi_value_py_obj
+    # keep the python object alive since it is a temporary object
+    # we must push to the extra temp py objects stack to avoid overflowing the temp py objects stack
+ TVMFFIPyPushExtraTempPyObject(ctx, ffi_value_py_obj_ptr)
+ return TVMFFIPySetArgumentGenericDispatcher(
+ TVMFFIPyArgSetterFactory_, ctx, ffi_value_py_obj_ptr, out
+ )
+
+
cdef _DISPATCH_TYPE_KEEP_ALIVE = set()
cdef _DISPATCH_TYPE_KEEP_ALIVE_LOCK = threading.Lock()
@@ -824,6 +840,9 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
if hasattr(arg_class, "__tvm_ffi_float__"):
out.func = TVMFFIPyArgSetterFloatProtocol_
return 0
+ if hasattr(arg_class, "__tvm_ffi_value__"):
+ out.func = TVMFFIPyArgSetterFFIValueProtocol_
+ return 0
if isinstance(arg, Exception):
out.func = TVMFFIPyArgSetterException_
return 0
diff --git a/python/tvm_ffi/cython/tensor.pxi b/python/tvm_ffi/cython/tensor.pxi
index 844f7ef..841687a 100644
--- a/python/tvm_ffi/cython/tensor.pxi
+++ b/python/tvm_ffi/cython/tensor.pxi
@@ -438,7 +438,7 @@ def _dltensor_test_wrapper_exchange_api_ptr():
cdef class DLTensorTestWrapper:
"""Wrapper of a Tensor that exposes DLPack protocol, only for testing
purpose.
"""
- __c_dlpack_exchange_api__: int = _dltensor_test_wrapper_exchange_api_ptr()
+ __c_dlpack_exchange_api__ = _dltensor_test_wrapper_exchange_api_ptr()
cdef Tensor tensor
cdef dict __dict__
diff --git a/python/tvm_ffi/cython/tvm_ffi_python_helpers.h b/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
index dc970d3..88c27d7 100644
--- a/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
+++ b/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
@@ -40,16 +40,43 @@
#include <exception>
#include <iostream>
#include <unordered_map>
+#include <vector>
///--------------------------------------------------------------------------------
/// We deliberately designed the data structure and function to be C-style
// prefixed with TVMFFIPy so they can be easily invoked through Cython.
///--------------------------------------------------------------------------------
+/*!
+ * \brief Thread-local call stack used by TVMFFIPyCallContext.
+ */
+class TVMFFIPyCallStack {
+ public:
+ /*! \brief The stack of arguments */
+ std::vector<TVMFFIAny> args_stack;
+  /*! \brief The current top of the argument stack */
+ int64_t args_stack_top = 0;
+ /*!
+   * \brief The stack of extra temporary Python objects that may not fit within the
+   * one-temp-per-argument budget, mainly used by the value protocol.
+ */
+ std::vector<PyObject*> extra_temp_py_objects_stack;
+
+ /*! \brief Constructor to initialize the call stack */
+ TVMFFIPyCallStack() {
+ // keep it 4K as default stack size so it is page aligned
+ constexpr size_t kDefaultStackSize = 4096;
+    // fit everything within a roughly 4K stack
+ args_stack.resize(kDefaultStackSize / sizeof(TVMFFIAny));
+ extra_temp_py_objects_stack.reserve(16);
+ }
+};
+
/*!
 * \brief Context for each ffi call to track the stream, device and temporary arguments.
*/
-struct TVMFFIPyCallContext {
+class TVMFFIPyCallContext {
+ public:
/*! \brief The workspace for the packed arguments */
TVMFFIAny* packed_args = nullptr;
/*! \brief Detected device type, if any */
@@ -58,16 +85,77 @@ struct TVMFFIPyCallContext {
int device_id = 0;
/*! \brief Detected stream, if any */
void* stream = nullptr;
+ /*! \brief the DLPack exchange API, if any */
+ const DLPackExchangeAPI* c_dlpack_exchange_api{nullptr};
+ /*! \brief pointer to the call stack space */
+ TVMFFIPyCallStack* call_stack = nullptr;
/*! \brief the temporary arguments to be recycled */
void** temp_ffi_objects = nullptr;
- /*! \brief the number of temporary arguments */
- int num_temp_ffi_objects = 0;
/*! \brief the temporary arguments to be recycled */
void** temp_py_objects = nullptr;
/*! \brief the number of temporary arguments */
+ int num_temp_ffi_objects = 0;
+ /*! \brief the number of temporary arguments */
int num_temp_py_objects = 0;
- /*! \brief the DLPack exchange API, if any */
- const DLPackExchangeAPI* c_dlpack_exchange_api{nullptr};
+
+ /*! \brief RAII guard constructor to create a TVMFFIPyCallContext */
+  TVMFFIPyCallContext(TVMFFIPyCallStack* call_stack, int64_t num_args) : call_stack(call_stack) {
+ // In most cases, it will try to allocate from temp_stack,
+ // then allocate from heap if the request goes beyond the stack size.
+ static_assert(sizeof(TVMFFIAny) >= (sizeof(void*) * 2));
+ static_assert(alignof(TVMFFIAny) % alignof(void*) == 0);
+ old_args_stack_top_ = call_stack->args_stack_top;
+ int64_t requested_count = num_args * 2;
+    TVMFFIAny* stack_head = call_stack->args_stack.data() + call_stack->args_stack_top;
+ if (call_stack->args_stack_top + requested_count >
+ static_cast<int64_t>(call_stack->args_stack.size())) {
+ // allocate from heap
+ heap_ptr_ = new TVMFFIAny[requested_count];
+ stack_head = heap_ptr_;
+ } else {
+ call_stack->args_stack_top += requested_count;
+ }
+ this->packed_args = stack_head;
+ // by default we co-locate the temporary arguments with packed arguments
+    // for better cache locality with a one-temp-per-argument budget.
+ this->temp_ffi_objects = reinterpret_cast<void**>(stack_head + num_args);
+ this->temp_py_objects = this->temp_ffi_objects + num_args;
+    this->old_extra_temp_py_objects_stack_top_ = call_stack->extra_temp_py_objects_stack.size();
+ }
+
+ ~TVMFFIPyCallContext() {
+ try {
+ // recycle the temporary arguments if any
+ for (int i = 0; i < this->num_temp_ffi_objects; ++i) {
+ TVMFFIObjectDecRef(this->temp_ffi_objects[i]);
+ }
+ for (int i = 0; i < this->num_temp_py_objects; ++i) {
+ Py_DecRef(static_cast<PyObject*>(this->temp_py_objects[i]));
+ }
+ for (size_t i = old_extra_temp_py_objects_stack_top_;
+ i < call_stack->extra_temp_py_objects_stack.size(); ++i) {
+        Py_DecRef(static_cast<PyObject*>(call_stack->extra_temp_py_objects_stack[i]));
+      }
+      call_stack->extra_temp_py_objects_stack.resize(old_extra_temp_py_objects_stack_top_);
+ } catch (const std::exception& ex) {
+ // very rare, catch c++ exception and set python error
+ PyErr_SetString(PyExc_RuntimeError, ex.what());
+ }
+ // now recycle the memory of the call stack
+ if (heap_ptr_ == nullptr) {
+ call_stack->args_stack_top = old_args_stack_top_;
+ } else {
+ delete[] heap_ptr_;
+ }
+ }
+
+ private:
+ /*! \brief the heap pointer */
+ TVMFFIAny* heap_ptr_ = nullptr;
+ /*! \brief the old stack top */
+ size_t old_args_stack_top_;
+ /*! \brief the begin index of the temporary Python objects stack */
+ size_t old_extra_temp_py_objects_stack_top_;
};
/*! \brief Argument setter for a given python argument. */
@@ -173,66 +261,6 @@ class TVMFFIPyCallManager {
static thread_local TVMFFIPyCallManager inst;
return &inst;
}
- /*!
- * \brief auxiliary class that manages the call stack in RAII manner.
- *
- * In most cases, it will try to allocate from temp_stack,
- * then allocate from heap if the request goes beyond the stack size.
- */
- class CallStack : public TVMFFIPyCallContext {
- public:
-    CallStack(TVMFFIPyCallManager* manager, int64_t num_args) : manager_ptr_(manager) {
- static_assert(sizeof(TVMFFIAny) >= (sizeof(void*) * 2));
- static_assert(alignof(TVMFFIAny) % alignof(void*) == 0);
- old_stack_top_ = manager->stack_top_;
- int64_t requested_count = num_args * 2;
-      TVMFFIAny* stack_head = manager->temp_stack_.data() + manager->stack_top_;
- if (manager->stack_top_ + requested_count >
- static_cast<int64_t>(manager->temp_stack_.size())) {
- // allocate from heap
- heap_ptr_ = new TVMFFIAny[requested_count];
- stack_head = heap_ptr_;
- } else {
- manager->stack_top_ += requested_count;
- }
- this->packed_args = stack_head;
- this->temp_ffi_objects = reinterpret_cast<void**>(stack_head + num_args);
- this->temp_py_objects = this->temp_ffi_objects + num_args;
- }
-
- ~CallStack() {
- try {
- // recycle the temporary arguments if any
- for (int i = 0; i < this->num_temp_ffi_objects; ++i) {
- TVMFFIObjectDecRef(this->temp_ffi_objects[i]);
- }
- for (int i = 0; i < this->num_temp_py_objects; ++i) {
- Py_DecRef(static_cast<PyObject*>(this->temp_py_objects[i]));
- }
- } catch (const std::exception& ex) {
- // very rare, catch c++ exception and set python error
- PyErr_SetString(PyExc_RuntimeError, ex.what());
- }
- // now recycle the memory of the call stack
- if (heap_ptr_ == nullptr) {
- manager_ptr_->stack_top_ = old_stack_top_;
- } else {
- delete[] heap_ptr_;
- }
- }
-
- private:
- /*!
- *\brief The manager of the call stack
- * If stored on stack, must set it to point to parent.
- */
- TVMFFIPyCallManager* manager_ptr_ = nullptr;
- /*! \brief The heap of the call stack */
- TVMFFIAny* heap_ptr_ = nullptr;
- /*! \brief The old stack size */
- int64_t old_stack_top_ = 0;
- };
-
/*!
* \brief Call a function with a variable number of arguments
* \param setter_factory The factory function to create the setter
@@ -253,7 +281,7 @@ class TVMFFIPyCallManager {
if (num_args == -1) return -1;
try {
// allocate a call stack
- CallStack ctx(this, num_args);
+ TVMFFIPyCallContext ctx(&call_stack_, num_args);
// Iterate over the arguments and set them
for (int64_t i = 0; i < num_args; ++i) {
PyObject* py_arg = PyTuple_GetItem(py_arg_tuple, i);
@@ -335,7 +363,7 @@ class TVMFFIPyCallManager {
if (num_args == -1) return -1;
try {
// allocate a call stack
- CallStack ctx(this, num_args);
+ TVMFFIPyCallContext ctx(&call_stack_, num_args);
// Iterate over the arguments and set them
for (int64_t i = 0; i < num_args; ++i) {
PyObject* py_arg = PyTuple_GetItem(py_arg_tuple, i);
@@ -368,7 +396,7 @@ class TVMFFIPyCallManager {
TVMFFIFieldSetter field_setter, void* field_ptr,
PyObject* py_arg,
int* c_api_ret_code) {
try {
- CallStack ctx(this, 1);
+ TVMFFIPyCallContext ctx(&call_stack_, 1);
TVMFFIAny* c_arg = ctx.packed_args;
if (SetArgument(setter_factory, &ctx, py_arg, c_arg) != 0) return -1;
c_api_ret_code[0] = (*field_setter)(field_ptr, c_arg);
@@ -380,10 +408,10 @@ class TVMFFIPyCallManager {
}
}
-  int PyObjectToFFIAny(TVMFFIPyArgSetterFactory setter_factory, PyObject* py_arg, TVMFFIAny* out,
-                       int* c_api_ret_code) {
+  TVM_FFI_INLINE int PyObjectToFFIAny(TVMFFIPyArgSetterFactory setter_factory, PyObject* py_arg,
+                                      TVMFFIAny* out, int* c_api_ret_code) {
try {
- CallStack ctx(this, 1);
+ TVMFFIPyCallContext ctx(&call_stack_, 1);
TVMFFIAny* c_arg = ctx.packed_args;
if (SetArgument(setter_factory, &ctx, py_arg, c_arg) != 0) return -1;
c_api_ret_code[0] = TVMFFIAnyViewToOwnedAny(c_arg, out);
@@ -394,20 +422,7 @@ class TVMFFIPyCallManager {
return -1;
}
}
- /*!
- * \brief Get the size of the dispatch map
- * \return The size of the dispatch map
- */
- size_t GetDispatchMapSize() const { return dispatch_map_.size(); }
- private:
- TVMFFIPyCallManager() {
- static constexpr size_t kDefaultDispatchCapacity = 32;
- // keep it 4K as default stack size so it is page aligned
- static constexpr size_t kDefaultStackSize = 4096;
- dispatch_map_.reserve(kDefaultDispatchCapacity);
- temp_stack_.resize(kDefaultStackSize / sizeof(TVMFFIAny));
- }
/*!
 * \brief Set a py_arg to out.
* \param setter_factory The factory function to create the setter
@@ -443,11 +458,23 @@ class TVMFFIPyCallManager {
}
return 0;
}
+
+ /*!
+ * \brief Get the size of the dispatch map
+ * \return The size of the dispatch map
+ */
+ size_t GetDispatchMapSize() const { return dispatch_map_.size(); }
+
+ private:
+ TVMFFIPyCallManager() {
+ static constexpr size_t kDefaultDispatchCapacity = 32;
+ dispatch_map_.reserve(kDefaultDispatchCapacity);
+ }
+
// internal dispatcher
std::unordered_map<PyTypeObject*, TVMFFIPyArgSetter> dispatch_map_;
- // temp call stack
- std::vector<TVMFFIAny> temp_stack_;
- int64_t stack_top_ = 0;
+ // call stack
+ TVMFFIPyCallStack call_stack_;
};
/*!
@@ -514,6 +541,22 @@ TVM_FFI_INLINE int TVMFFIPyCallFieldSetter(TVMFFIPyArgSetterFactory setter_facto
py_arg, c_api_ret_code);
}
+/*!
+ * \brief Set a Python argument to a FFI Any using the generic dispatcher in the call manager
+ * \param setter_factory The factory function to create the setter
+ * \param ctx The call context
+ * \param py_arg_tvm_ffi_value The Python argument to be set using the __tvm_ffi_value__ protocol
+ * \param out The output argument
+ * \return 0 on success, nonzero on failure
+ */
+TVM_FFI_INLINE int TVMFFIPySetArgumentGenericDispatcher(TVMFFIPyArgSetterFactory setter_factory,
+                                                        TVMFFIPyCallContext* ctx,
+                                                        PyObject* py_arg_tvm_ffi_value,
+                                                        TVMFFIAny* out) {
+  return TVMFFIPyCallManager::ThreadLocal()->SetArgument(setter_factory, ctx, py_arg_tvm_ffi_value,
+                                                         out);
+}
+
/*!
* \brief Convert a Python object to a FFI Any
* \param setter_factory The factory function to create the setter
@@ -560,6 +603,17 @@ TVM_FFI_INLINE void TVMFFIPyPushTempPyObject(TVMFFIPyCallContext* ctx, PyObject*
ctx->temp_py_objects[ctx->num_temp_py_objects++] = arg;
}
+/*!
+ * \brief Push an extra temporary Python object to the call context; these may go beyond the
+ * one-temp-per-argument budget, mainly used by the value protocol.
+ * \param ctx The call context
+ * \param arg The Python object to push
+ */
+TVM_FFI_INLINE void TVMFFIPyPushExtraTempPyObject(TVMFFIPyCallContext* ctx, PyObject* arg) {
+ Py_IncRef(arg);
+ ctx->call_stack->extra_temp_py_objects_stack.emplace_back(arg);
+}
+
//----------------------------------------------------------
// Helpers for MLIR redirection
//----------------------------------------------------------
diff --git a/tests/python/test_function.py b/tests/python/test_function.py
index a24395a..8a494fb 100644
--- a/tests/python/test_function.py
+++ b/tests/python/test_function.py
@@ -383,3 +383,22 @@ def test_integral_float_variants_passing() -> None:
y = fecho(FloatProtocol(10))
assert isinstance(y, float)
assert y == 10
+
+
+def test_function_with_value_protocol() -> None:
+ class ValueProtocol:
+ def __init__(self, value: Any) -> None:
+ self.value = value
+
+ def __tvm_ffi_value__(self) -> Any:
+ return self.value
+
+ fecho = tvm_ffi.get_global_func("testing.echo")
+ assert fecho(ValueProtocol(10)) == 10
+ assert tuple(fecho(ValueProtocol([1, 2, 3]))) == (1, 2, 3)
+ assert tuple(fecho(ValueProtocol([1, 2, ValueProtocol(3)]))) == (1, 2, 3)
+ nested_value_protocol = ValueProtocol(ValueProtocol(ValueProtocol(10)))
+ assert fecho(nested_value_protocol) == 10
+
+    nested_value_protocol = ValueProtocol([ValueProtocol(1), ValueProtocol(2), ValueProtocol(3)])
+ assert tuple(fecho(nested_value_protocol)) == (1, 2, 3)
diff --git a/tests/python/test_stream.py b/tests/python/test_stream.py
index 3b58ccb..fa5c9b3 100644
--- a/tests/python/test_stream.py
+++ b/tests/python/test_stream.py
@@ -139,6 +139,9 @@ def test_torch_graph() -> None:
device_type = device.dlpack_device_type()
graph = torch.cuda.CUDAGraph()
stream = torch.cuda.Stream(device_id)
+ x = torch.zeros(1, device="cuda")
with tvm_ffi.use_torch_stream(torch.cuda.graph(graph, stream=stream)):
assert torch.cuda.current_stream() == stream
mod.check_stream(device_type, device_id, stream.cuda_stream)
+    # avoid the cuda graph no-capture warning
+ x = x + 1
diff --git a/tests/python/test_tensor.py b/tests/python/test_tensor.py
index 0551d14..9c938a8 100644
--- a/tests/python/test_tensor.py
+++ b/tests/python/test_tensor.py
@@ -18,6 +18,7 @@
from __future__ import annotations
from types import ModuleType
+from typing import Any, NamedTuple
import pytest
@@ -113,6 +114,28 @@ def test_tvm_ffi_tensor_compatible() -> None:
z = fecho(y)
assert z.__chandle__() == x.__chandle__()
+ class MyNamedTuple(NamedTuple):
+ a: MyTensor
+ b: int
+
+ args = MyNamedTuple(a=y, b=1)
+ z = fecho(args)
+ assert z[0].__chandle__() == x.__chandle__()
+ assert z[1] == 1
+
+ class MyCustom:
+ def __init__(self, a: MyTensor, b: int) -> None:
+ self.a = a
+ self.b = b
+
+ def __tvm_ffi_value__(self) -> Any:
+ """Implement __tvm_ffi_value__ protocol."""
+ return (self.a, self.b)
+
+ z = fecho(MyCustom(a=y, b=2))
+ assert z[0].__chandle__() == x.__chandle__()
+ assert z[1] == 2
+
@pytest.mark.skipif(
    torch is None or not torch.cuda.is_available() or torch.version.hip is None,
diff --git a/tests/scripts/benchmark_dlpack.py b/tests/scripts/benchmark_dlpack.py
index 23db3a8..4798d00 100644
--- a/tests/scripts/benchmark_dlpack.py
+++ b/tests/scripts/benchmark_dlpack.py
@@ -33,7 +33,7 @@ Summary of some takeaways:
from __future__ import annotations
import time
-from typing import Any, Callable
+from typing import Any, Callable, NamedTuple
import numpy as np
import torch
@@ -52,6 +52,14 @@ class TestFFITensor:
return self._tensor
+class TestNamedTuple(NamedTuple):
+ """Test FFI NamedTuple."""
+
+ x: torch.Tensor
+ y: torch.Tensor
+ z: torch.Tensor
+
+
def print_speed(name: str, speed: float) -> None:
print(f"{name:<60} {speed} sec/call")
@@ -231,6 +239,20 @@ def bench_tvm_ffi_nop_autodlpack(name: str, x: Any, y: Any, z: Any, repeat: int)
print_speed(name, speed)
+def bench_tvm_ffi_nop_autodlpack_tuple(name: str, args: TestNamedTuple, repeat: int) -> None:
+ """Measures overhead of running dlpack via auto convert by directly
+ take torch.Tensor as inputs.
+ """
+ nop = tvm_ffi.get_global_func("testing.nop")
+ nop(args)
+ start = time.time()
+ for i in range(repeat):
+ nop(args)
+ end = time.time()
+ speed = (end - start) / repeat
+ print_speed(name, speed)
+
+
def tvm_ffi_nop_autodlpack_from_torch(
repeat: int, device: str = "cpu", stream: bool = False
) -> None:
@@ -276,18 +298,16 @@ def tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat: int, device: str)
)
-def tvm_ffi_nop_autodlpack_from_test_ffi_tensor(repeat: int, device: str) -> None:
+def tvm_ffi_nop_autodlpack_from_test_tensor_namedtuple(repeat: int, device: str) -> None:
"""Measures overhead of running dlpack via auto convert by directly
take test wrapper as inputs. This effectively measure DLPack exchange in
tvm ffi.
"""
- x = tvm_ffi.from_dlpack(torch.arange(1, device=device))
- y = tvm_ffi.from_dlpack(torch.arange(1, device=device))
- z = tvm_ffi.from_dlpack(torch.arange(1, device=device))
- x = TestFFITensor(x)
- y = TestFFITensor(y)
- z = TestFFITensor(z)
- bench_tvm_ffi_nop_autodlpack(
- f"tvm_ffi.nop.autodlpack(TestFFITensor[{device}])", x, y, z, repeat
+ x = torch.arange(1, device=device)
+ y = torch.arange(1, device=device)
+ z = torch.arange(1, device=device)
+ args = TestNamedTuple(x=x, y=y, z=z)
+ bench_tvm_ffi_nop_autodlpack_tuple(
+ f"tvm_ffi.nop.autodlpack(NamedTuple[{device}])", args, repeat
)
@@ -414,8 +434,8 @@ def main() -> None: # noqa: PLR0915
tvm_ffi_nop_autodlpack_from_numpy(repeat)
tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat, "cpu")
tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat, "cuda")
- tvm_ffi_nop_autodlpack_from_test_ffi_tensor(repeat, "cpu")
- tvm_ffi_nop_autodlpack_from_test_ffi_tensor(repeat, "cuda")
+ tvm_ffi_nop_autodlpack_from_test_tensor_namedtuple(repeat, "cpu")
+ tvm_ffi_nop_autodlpack_from_test_tensor_namedtuple(repeat, "cuda")
tvm_ffi_nop(repeat)
print("-------------------------------")
print("Benchmark x.__dlpack__ overhead")